-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
677 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,3 +162,5 @@ cython_debug/ | |
#.idea/ | ||
|
||
.DS_Store | ||
|
||
.depthai_cached_models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import depthai as dai | ||
from utils.arguments import initialize_argparser, parse_model_slug | ||
from utils.model import get_input_shape, get_model_from_hub, get_parser, setup_parser | ||
from visualization.visualize import visualize | ||
|
||
# Entry script: downloads a HubAI model, builds a DepthAI pipeline around it
# (camera -> [optional resize] -> network -> [optional host parser]) and
# visualizes the parsed results until the visualizer asks to stop.

# Initialize the argument parser
arg_parser, args = initialize_argparser()

# Parse the model slug
model_slug, model_version_slug = parse_model_slug(args)

# Get the model from the HubAI
nn_archive = get_model_from_hub(model_slug, model_version_slug)

# Get the parser
parser_class, parser_name = get_parser(nn_archive)
input_shape = get_input_shape(nn_archive)

if parser_name == "XFeatParser":
    raise NotImplementedError("XFeatParser is not supported in this script yet.")

# Create the pipeline
with dai.Pipeline() as pipeline:
    cam = pipeline.create(dai.node.Camera).build()

    # YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser
    if parser_name in ("YOLO", "SSD"):
        network = pipeline.create(dai.node.DetectionNetwork).build(
            cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p), nn_archive
        )
        parser_queue = network.out.createOutputQueue()
    else:
        # Grayscale models are recognised purely by the "gray" substring in
        # the version slug; everything else gets planar BGR frames.
        image_type = dai.ImgFrame.Type.BGR888p
        if "gray" in model_version_slug:
            image_type = dai.ImgFrame.Type.GRAY8

        if input_shape[0] < 128 or input_shape[1] < 128:
            print(
                "Input shape is too small so we are requesting a larger image and resizing it."
            )
            print(
                "During visualization we resize small image back to large, so image quality is lower."
            )
            # Request a 4x larger camera output and shrink it on-device to
            # the size the network expects.
            manip = pipeline.create(dai.node.ImageManip)
            manip.initialConfig.setResize(input_shape)
            large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
            cam.requestOutput(large_input_shape, type=image_type).link(manip.inputImage)
            network = pipeline.create(dai.node.NeuralNetwork).build(
                manip.out, nn_archive
            )
        else:
            network = pipeline.create(dai.node.NeuralNetwork).build(
                cam.requestOutput(input_shape, type=image_type), nn_archive
            )

        parser = pipeline.create(parser_class)
        setup_parser(parser, nn_archive, parser_name)

        # Linking
        network.out.link(parser.input)

        parser_queue = parser.out.createOutputQueue()

    # The network's passthrough output provides the frames used for drawing.
    camera_queue = network.passthrough.createOutputQueue()

    # The head metadata comes from the archive and does not change while the
    # pipeline runs, so it is read once here instead of on every frame.
    extraParams = (
        nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
    )

    pipeline.start()

    while pipeline.isRunning():
        # getCvFrame() output is handed straight to the visualizer.
        frame = camera_queue.get().getCvFrame()
        message = parser_queue.get()

        # visualize() returns a truthy value when the loop should end.
        if visualize(frame, message, parser_name, extraParams):
            pipeline.stop()
            break
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import argparse | ||
from typing import Tuple | ||
|
||
|
||
def initialize_argparser():
    """Build the command-line parser for the script and parse ``sys.argv``.

    The parser exposes a single required option, ``-s/--model_slug``.

    Returns:
        A ``(parser, args)`` pair: the configured ``argparse.ArgumentParser``
        and the ``argparse.Namespace`` returned by ``parse_args()``.
    """
    # Adjacent string literals are joined at compile time, so the description
    # contains exactly the text written here. Backslash continuations inside
    # one literal would also capture the indentation of the continued lines.
    description = (
        "General example script to run any model available in HubAI on DepthAI device. "
        "All you need is a model slug of the model and the script will download the model from HubAI and create "
        "the whole pipeline with visualizations. You also need a DepthAI device connected to your computer. "
        "Currently, only RVC2 devices are supported."
    )
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "-s",
        "--model_slug",
        help="slug of the model in HubAI.",
        required=True,
        type=str,
    )

    args = parser.parse_args()

    return parser, args
|
||
|
||
def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]:
    """Split ``args.model_slug`` into the model slug and the version slug.

    The expected format is ``'model_slug:model_version_slug'``. Only the text
    before the first colon and the text between the first and second colon
    are used; anything after a second colon is ignored.

    Args:
        args: Parsed command-line arguments carrying a ``model_slug`` string.

    Returns:
        A ``(model_slug, model_version_slug)`` tuple of non-empty strings.

    Raises:
        NameError: If the value has no colon, or if either part is empty
            (for example ``"model:"`` or ``":version"``).
    """
    model_slug, separator, remainder = args.model_slug.partition(":")
    model_version_slug = remainder.split(":", 1)[0]

    # Reject empty parts up front instead of letting them fail later during
    # the model download with a less helpful error.
    if not separator or not model_slug or not model_version_slug:
        raise NameError(
            "Please provide the model slug in the format of 'model_slug:model_version_slug'"
        )

    return model_slug, model_version_slug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
from typing import List, Tuple | ||
|
||
import depthai as dai | ||
|
||
from depthai_nodes.ml.parsers import ( | ||
ClassificationParser, | ||
KeypointParser, | ||
MonocularDepthParser, | ||
SCRFDParser, | ||
SegmentationParser, | ||
XFeatParser, | ||
) | ||
|
||
|
||
def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive:
    """Download a model from HubAI and wrap it in an NN archive.

    Args:
        model_slug: Slug identifying the model.
        model_version_slug: Slug identifying the model version.

    Returns:
        A ``dai.NNArchive`` built from the path returned by
        ``dai.getModelFromZoo``. The platform is fixed to ``"RVC2"``.
    """
    print(
        f"Downloading model {model_slug} with version {model_version_slug} from HubAI..."
    )
    description = dai.NNModelDescription(
        modelSlug=model_slug,
        modelVersionSlug=model_version_slug,
        platform="RVC2",
    )
    archive_path = dai.getModelFromZoo(description)
    print("Download successful!")

    return dai.NNArchive(archive_path)
|
||
|
||
def get_parser_from_archive(nn_archive: dai.NNArchive) -> str:
    """Return the parser name declared by the first head of the NN archive.

    Reads ``model.heads[0].parser`` from the archive's V1 config and prints it.

    Args:
        nn_archive: The NN archive to inspect.

    Returns:
        The value of the first head's ``parser`` field.

    Raises:
        SystemExit: With exit code 1, after printing a message, when the
            parser cannot be read from the archive.
    """
    try:
        required_parser = nn_archive.getConfig().getConfigV1().model.heads[0].parser
    except (AttributeError, IndexError, TypeError):
        # IndexError covers an empty ``heads`` list. TypeError covers a
        # non-subscriptable ``heads`` (presumably None when the archive
        # declares no heads -- confirm against the NNArchive schema).
        print(
            "This NN archive does not have a parser. Please use NN archives that have parsers."
        )
        # ``exit()`` is installed by the ``site`` module for interactive use;
        # raising SystemExit directly does not depend on it.
        raise SystemExit(1) from None

    print(f"Required parser: {required_parser}")

    return required_parser
|
||
|
||
def get_parser(nn_archive: dai.NNArchive) -> Tuple[dai.ThreadedNode, str]:
    """Map the parser named in the NN archive to a parser class of this module.

    Args:
        nn_archive: The NN archive whose first head names the parser.

    Returns:
        A ``(parser_class, parser_name)`` pair. For ``"YOLO"`` and ``"SSD"``
        the first element is ``None``; these names are returned without a
        class lookup.

    Raises:
        NameError: If no class with the required name is available in this
            module's namespace.
    """
    required_parser = get_parser_from_archive(nn_archive)

    if required_parser in ("YOLO", "SSD"):
        return None, required_parser

    # The name comes from the archive, so only accept module-level names that
    # are classes. This stops an unrelated global (an imported module, a
    # helper function, a typing alias) from being returned as a parser.
    parser = globals().get(required_parser)

    if not isinstance(parser, type):
        raise NameError(
            f"Parser {required_parser} is not available in the depthai_nodes.ml.parsers module."
        )

    return parser, required_parser
|
||
|
||
def get_inputs_from_archive(nn_archive: dai.NNArchive) -> List:
    """Return the ``model.inputs`` entry of the NN archive's V1 config.

    Args:
        nn_archive: The NN archive to inspect.

    Returns:
        Whatever ``nn_archive.getConfig().getConfigV1().model.inputs`` holds.

    Raises:
        SystemExit: With exit code 1, after printing a message, when the
            attribute chain cannot be resolved.
    """
    try:
        return nn_archive.getConfig().getConfigV1().model.inputs
    except AttributeError:
        print(
            "This NN archive does not have an input shape. Please use NN archives that have input shapes."
        )
        # ``exit()`` is installed by the ``site`` module for interactive use;
        # raising SystemExit directly does not depend on it.
        raise SystemExit(1) from None
|
||
|
||
def get_input_shape(nn_archive: dai.NNArchive) -> Tuple[int, int]:
    """Return the model's input size taken from the NN archive.

    The size is the input's ``shape`` with the first two entries dropped and
    the rest reversed.
    NOTE(review): this presumably assumes an NCHW layout, which would make
    the result ``(width, height)`` -- confirm against the archives in use.

    Args:
        nn_archive: The NN archive to inspect.

    Returns:
        A two-element tuple with the reversed trailing dimensions.

    Raises:
        ValueError: If the model does not have exactly one input, or if the
            shape does not leave exactly two dimensions after the first two
            are dropped.
        SystemExit: With exit code 1, after printing a message, when the
            input exposes no usable ``shape``.
    """
    inputs = get_inputs_from_archive(nn_archive)

    if len(inputs) > 1:
        raise ValueError(
            "This model has more than one input. Currently, only models with one input are supported."
        )
    if len(inputs) == 0:
        # Without this check, ``inputs[0]`` below fails with a bare IndexError.
        raise ValueError(
            "This model has no inputs. Currently, only models with one input are supported."
        )

    try:
        # Converted to a tuple so the result matches the annotated return type.
        input_shape = tuple(inputs[0].shape[2:][::-1])
    except (AttributeError, TypeError):
        print(
            "This NN archive does not have an input shape. Please use NN archives that have input shapes."
        )
        # ``exit()`` is installed by the ``site`` module for interactive use;
        # raising SystemExit directly does not depend on it.
        raise SystemExit(1) from None

    if len(input_shape) != 2:
        raise ValueError(
            f"Expected an input shape with two spatial dimensions, got {inputs[0].shape}."
        )

    print(f"Input shape: {input_shape}")

    return input_shape
|
||
|
||
def setup_scrfd_parser(parser: SCRFDParser, params: dict):
    """Configure an SCRFD parser from the archive's extra parameters.

    Reads ``"num_anchors"`` and ``"feat_stride_fpn"`` from ``params`` and
    passes them to the parser. Setup is best effort: on any exception a
    message is printed and the function returns normally.
    """
    try:
        # Both values are read before either setter runs, so a missing key
        # leaves the parser untouched.
        anchors, strides = params["num_anchors"], params["feat_stride_fpn"]
        parser.setNumAnchors(anchors)
        parser.setFeatStrideFPN(strides)
    except Exception:
        print(
            "This NN archive does not have required metadata for SCRFDParser. Skipping setup..."
        )
|
||
|
||
def setup_segmentation_parser(parser: SegmentationParser, params: dict):
    """Configure a segmentation parser from the archive's extra parameters.

    Passes ``params["background_class"]`` to the parser. Setup is best
    effort: on any exception a message is printed and the function returns
    normally.
    """
    try:
        parser.setBackgroundClass(params["background_class"])
    except Exception:
        print(
            "This NN archive does not have required metadata for SegmentationParser. Skipping setup..."
        )
|
||
|
||
def setup_keypoint_parser(parser: KeypointParser, params: dict):
    """Configure a keypoint parser from the archive's extra parameters.

    Reads ``"n_keypoints"`` and ``"scale_factor"`` from ``params`` and passes
    them to the parser. Setup is best effort: on any exception a message is
    printed and the function returns normally.
    """
    try:
        # Both values are read before either setter runs, so a missing key
        # leaves the parser untouched.
        keypoint_count, scale = params["n_keypoints"], params["scale_factor"]
        parser.setNumKeypoints(keypoint_count)
        parser.setScaleFactor(scale)
    except Exception:
        print(
            "This NN archive does not have required metadata for KeypointParser. Skipping setup..."
        )
|
||
|
||
def setup_classification_parser(parser: ClassificationParser, params: dict):
    """Configure a classification parser from the archive's extra parameters.

    Reads ``"classes"`` and ``"is_softmax"`` from ``params`` and passes them
    to the parser. Setup is best effort: on any exception a message is
    printed and the function returns normally.
    """
    try:
        # Both values are read before either setter runs, so a missing key
        # leaves the parser untouched.
        class_names, softmax_flag = params["classes"], params["is_softmax"]
        parser.setClasses(class_names)
        parser.setSoftmax(softmax_flag)
    except Exception:
        print(
            "This NN archive does not have required metadata for ClassificationParser. Skipping setup..."
        )
|
||
|
||
def setup_monocular_depth_parser(parser: MonocularDepthParser, params: dict):
    """Configure a monocular depth parser from the archive's extra parameters.

    Selects the relative depth type when ``params["depth_type"]`` equals
    ``"relative"`` and the metric depth type for any other value. Setup is
    best effort: on any exception a message is printed and the function
    returns normally.
    """
    try:
        if params["depth_type"] == "relative":
            parser.setRelativeDepthType()
        else:
            parser.setMetricDepthType()
    except Exception:
        print(
            "This NN archive does not have required metadata for MonocularDepthParser. Skipping setup..."
        )
|
||
|
||
def setup_xfeat_parser(parser: XFeatParser, params: dict):
    """Configure an XFeat parser from the archive's extra parameters.

    Uses ``params["input_size"]`` for both the parser's input size and its
    original size. Setup is best effort: on any exception a message is
    printed and the function returns normally.
    """
    try:
        size = params["input_size"]
        parser.setInputSize(size)
        parser.setOriginalSize(size)
    except Exception:
        print(
            "This NN archive does not have required metadata for XFeatParser. Skipping setup..."
        )
|
||
|
||
def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
    """Apply parser-specific setup using the first head's extra parameters.

    Args:
        parser: The parser node to configure.
        nn_archive: Archive whose ``model.heads[0].metadata.extraParams`` is
            handed to the setup helper.
        parser_name: Name selecting the setup helper. Names without a helper
            leave the parser unchanged.
    """
    first_head = nn_archive.getConfig().getConfigV1().model.heads[0]
    extra_params = first_head.metadata.extraParams

    # Parser name -> helper that knows which extra parameters it needs.
    setup_by_name = {
        "SCRFDParser": setup_scrfd_parser,
        "SegmentationParser": setup_segmentation_parser,
        "KeypointParser": setup_keypoint_parser,
        "ClassificationParser": setup_classification_parser,
        "MonocularDepthParser": setup_monocular_depth_parser,
        "XFeatParser": setup_xfeat_parser,
    }

    configure = setup_by_name.get(parser_name)
    if configure is not None:
        configure(parser, extra_params)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import cv2 | ||
import depthai as dai | ||
|
||
from depthai_nodes.ml.messages import AgeGender, Classifications | ||
|
||
from .messages import parse_classification_message, parser_age_gender_message | ||
|
||
|
||
def visualize_classification(
    frame: dai.ImgFrame, message: Classifications, extraParams: dict
):
    """Draw the first two classification results on the frame and show it.

    Args:
        frame: Image to draw on; it is accessed through ``.shape`` and passed
            to OpenCV drawing calls.
        message: Classification message to read classes and scores from.
        extraParams: Not used by this function.

    Returns:
        True if the "q" key was pressed (all OpenCV windows are closed
        first), otherwise False.
    """
    labels, confidences = parse_classification_message(message)
    top_results = zip(labels[:2], confidences[:2])

    # Frames less than 128 rows tall are doubled in size before drawing.
    height, width = frame.shape[0], frame.shape[1]
    if height < 128:
        frame = cv2.resize(frame, (width * 2, height * 2))

    # One text line per result, 20 px apart.
    for row, (label, confidence) in enumerate(top_results):
        cv2.putText(
            frame,
            f"{label}: {confidence:.2f}",
            (10, 20 + 20 * row),
            cv2.FONT_HERSHEY_TRIPLEX,
            0.5,
            255,
        )

    cv2.imshow("Classification", frame)

    quit_requested = cv2.waitKey(1) == ord("q")
    if quit_requested:
        cv2.destroyAllWindows()
    return quit_requested
|
||
|
||
def visualize_age_gender(frame: dai.ImgFrame, message: AgeGender, extraParams: dict):
    """Draw the age and the gender scores on the frame and show it.

    Args:
        frame: Image to draw on; it is accessed through ``.shape`` and passed
            to OpenCV drawing calls.
        message: Age-gender message to read the values from.
        extraParams: Not used by this function.

    Returns:
        True if the "q" key was pressed (all OpenCV windows are closed
        first), otherwise False.
    """
    # Frames less than 128 rows tall are doubled in size before drawing.
    height, width = frame.shape[0], frame.shape[1]
    if height < 128:
        frame = cv2.resize(frame, (width * 2, height * 2))

    age, gender_labels, gender_confidences = parser_age_gender_message(message)

    # The age goes on the first text line, the gender scores below it.
    cv2.putText(frame, f"Age: {age}", (10, 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
    for row, (label, confidence) in enumerate(zip(gender_labels, gender_confidences)):
        cv2.putText(
            frame,
            f"{label}: {confidence:.2f}",
            (10, 40 + 20 * row),
            cv2.FONT_HERSHEY_TRIPLEX,
            0.5,
            255,
        )

    cv2.imshow("Age-gender", frame)

    quit_requested = cv2.waitKey(1) == ord("q")
    if quit_requested:
        cv2.destroyAllWindows()
    return quit_requested
Oops, something went wrong.