From 1c353dd6e23ab9901dc6ccae8840ebb73f9cb52d Mon Sep 17 00:00:00 2001
From: kkeroo <61207502+kkeroo@users.noreply.github.com>
Date: Fri, 30 Aug 2024 13:34:42 +0200
Subject: [PATCH] Examples added.

---
 .gitignore                               |   2 +
 examples/main.py                         |  77 +++++++++
 examples/utils/__init__.py               |   0
 examples/utils/arguments.py              |  43 +++++
 examples/utils/model.py                  | 191 +++++++++++++++++++++++
 examples/visualization/__init__.py       |   0
 examples/visualization/classification.py |  57 +++++++
 examples/visualization/colors.py         |  39 +++++
 examples/visualization/detection.py      | 102 ++++++++++++
 examples/visualization/image.py          |  15 ++
 examples/visualization/keypoints.py      |  23 +++
 examples/visualization/mapping.py        |  23 +++
 examples/visualization/messages.py       |  52 ++++++
 examples/visualization/segmentation.py   |  41 +++++
 examples/visualization/visualize.py      |  12 ++
 15 files changed, 677 insertions(+)
 create mode 100644 examples/main.py
 create mode 100644 examples/utils/__init__.py
 create mode 100644 examples/utils/arguments.py
 create mode 100644 examples/utils/model.py
 create mode 100644 examples/visualization/__init__.py
 create mode 100644 examples/visualization/classification.py
 create mode 100644 examples/visualization/colors.py
 create mode 100644 examples/visualization/detection.py
 create mode 100644 examples/visualization/image.py
 create mode 100644 examples/visualization/keypoints.py
 create mode 100644 examples/visualization/mapping.py
 create mode 100644 examples/visualization/messages.py
 create mode 100644 examples/visualization/segmentation.py
 create mode 100644 examples/visualization/visualize.py

diff --git a/.gitignore b/.gitignore
index c772aba7..2952c645 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,3 +162,5 @@ cython_debug/
 #.idea/
 
 .DS_Store
+
+.depthai_cached_models/
diff --git a/examples/main.py b/examples/main.py
new file mode 100644
index 00000000..4428e6c9
--- /dev/null
+++ b/examples/main.py
@@ -0,0 +1,77 @@
+import depthai as dai
+from utils.arguments import initialize_argparser, parse_model_slug
+from utils.model import get_input_shape, get_model_from_hub, get_parser, setup_parser
+from visualization.visualize import visualize
+
+# Initialize the argument parser
+arg_parser, args = initialize_argparser()
+
+# Parse the model slug
+model_slug, model_version_slug = parse_model_slug(args)
+
+# Get the model from the HubAI
+nn_archive = get_model_from_hub(model_slug, model_version_slug)
+
+# Get the parser
+parser_class, parser_name = get_parser(nn_archive)
+input_shape = get_input_shape(nn_archive)
+
+if parser_name == "XFeatParser":
+    raise NotImplementedError("XFeatParser is not supported in this script yet.")
+
+# Create the pipeline
+with dai.Pipeline() as pipeline:
+    cam = pipeline.create(dai.node.Camera).build()
+
+    # YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser
+    if parser_name == "YOLO" or parser_name == "SSD":
+        network = pipeline.create(dai.node.DetectionNetwork).build(
+            cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p), nn_archive
+        )
+        parser_queue = network.out.createOutputQueue()
+    else:
+        image_type = dai.ImgFrame.Type.BGR888p
+        if "gray" in model_version_slug:
+            image_type = dai.ImgFrame.Type.GRAY8
+
+        if input_shape[0] < 128 or input_shape[1] < 128:
+            print(
+                "Input shape is too small, so we are requesting a larger image and resizing it."
+            )
+            print(
+                "During visualization we resize the small image back to a larger one, so image quality is lower."
+            )
+            manip = pipeline.create(dai.node.ImageManip)
+            manip.initialConfig.setResize(input_shape)
+            large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
+            cam.requestOutput(large_input_shape, type=image_type).link(manip.inputImage)
+            network = pipeline.create(dai.node.NeuralNetwork).build(
+                manip.out, nn_archive
+            )
+        else:
+            network = pipeline.create(dai.node.NeuralNetwork).build(
+                cam.requestOutput(input_shape, type=image_type), nn_archive
+            )
+
+        parser = pipeline.create(parser_class)
+        setup_parser(parser, nn_archive, parser_name)
+
+        # Linking
+        network.out.link(parser.input)
+
+        parser_queue = parser.out.createOutputQueue()
+
+    camera_queue = network.passthrough.createOutputQueue()
+
+    pipeline.start()
+
+    while pipeline.isRunning():
+        frame: dai.ImgFrame = camera_queue.get().getCvFrame()
+        message = parser_queue.get()
+
+        extraParams = (
+            nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
+        )
+        if visualize(frame, message, parser_name, extraParams):
+            pipeline.stop()
+            break
diff --git a/examples/utils/__init__.py b/examples/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/utils/arguments.py b/examples/utils/arguments.py
new file mode 100644
index 00000000..3551bf5c
--- /dev/null
+++ b/examples/utils/arguments.py
@@ -0,0 +1,43 @@
+import argparse
+from typing import Tuple
+
+
+def initialize_argparser():
+    """Initialize the argument parser for the script."""
+    parser = argparse.ArgumentParser()
+    parser.description = "General example script to run any model available in HubAI on a DepthAI device. \
+        All you need is the model slug, and the script will download the model from HubAI and create \
+        the whole pipeline with visualizations. You also need a DepthAI device connected to your computer. \
+        Currently, only RVC2 devices are supported."
+
+    parser.add_argument(
+        "-s",
+        "--model_slug",
+        help="slug of the model in HubAI.",
+        required=True,
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    return parser, args
+
+
+def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]:
+    """Parse the model slug from the arguments.
+
+    Returns the model slug and model version slug.
+    """
+    model_slug = args.model_slug
+
+    # parse the model slug
+    if ":" not in model_slug:
+        raise NameError(
+            "Please provide the model slug in the format of 'model_slug:model_version_slug'"
+        )
+
+    model_slug_parts = model_slug.split(":")
+    model_slug = model_slug_parts[0]
+    model_version_slug = model_slug_parts[1]
+
+    return model_slug, model_version_slug
diff --git a/examples/utils/model.py b/examples/utils/model.py
new file mode 100644
index 00000000..67a7863c
--- /dev/null
+++ b/examples/utils/model.py
@@ -0,0 +1,191 @@
+from typing import List, Tuple
+
+import depthai as dai
+
+from depthai_nodes.ml.parsers import (
+    ClassificationParser,
+    KeypointParser,
+    MonocularDepthParser,
+    SCRFDParser,
+    SegmentationParser,
+    XFeatParser,
+)
+
+
+def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive:
+    """Get the model from the HubAI and return the NN archive."""
+    print(
+        f"Downloading model {model_slug} with version {model_version_slug} from HubAI..."
+    )
+    modelDescription = dai.NNModelDescription(
+        modelSlug=model_slug, modelVersionSlug=model_version_slug, platform="RVC2"
+    )
+    archivePath = dai.getModelFromZoo(modelDescription)
+    print("Download successful!")
+    nn_archive = dai.NNArchive(archivePath)
+
+    return nn_archive
+
+
+def get_parser_from_archive(nn_archive: dai.NNArchive) -> str:
+    """Get the required parser from the NN archive."""
+    try:
+        required_parser = nn_archive.getConfig().getConfigV1().model.heads[0].parser
+    except AttributeError:
+        print(
+            "This NN archive does not have a parser. Please use NN archives that have parsers."
+        )
+        exit(1)
+
+    print(f"Required parser: {required_parser}")
+
+    return required_parser
+
+
+def get_parser(nn_archive: dai.NNArchive) -> Tuple[dai.ThreadedNode, str]:
+    """Map the parser from the NN archive to the actual parser in depthai-nodes."""
+    required_parser = get_parser_from_archive(nn_archive)
+
+    if required_parser == "YOLO" or required_parser == "SSD":
+        return None, required_parser
+
+    parser = globals().get(required_parser, None)
+
+    if parser is None:
+        raise NameError(
+            f"Parser {required_parser} is not available in the depthai_nodes.ml.parsers module."
+        )
+
+    return parser, required_parser
+
+
+def get_inputs_from_archive(nn_archive: dai.NNArchive) -> List:
+    """Get all inputs from NN archive."""
+    try:
+        inputs = nn_archive.getConfig().getConfigV1().model.inputs
+    except AttributeError:
+        print(
+            "This NN archive does not have an input shape. Please use NN archives that have input shapes."
+        )
+        exit(1)
+
+    return inputs
+
+
+def get_input_shape(nn_archive: dai.NNArchive) -> Tuple[int, int]:
+    """Get the input shape of the model from the NN archive."""
+    inputs = get_inputs_from_archive(nn_archive)
+
+    if len(inputs) > 1:
+        raise ValueError(
+            "This model has more than one input. Currently, only models with one input are supported."
+        )
+
+    try:
+        input_shape = inputs[0].shape[2:][::-1]
+    except AttributeError:
+        print(
+            "This NN archive does not have an input shape. Please use NN archives that have input shapes."
+        )
+        exit(1)
+
+    print(f"Input shape: {input_shape}")
+
+    return input_shape
+
+
+def setup_scrfd_parser(parser: SCRFDParser, params: dict):
+    """Setup the SCRFD parser with the required metadata."""
+    try:
+        num_anchors = params["num_anchors"]
+        feat_stride_fpn = params["feat_stride_fpn"]
+        parser.setNumAnchors(num_anchors)
+        parser.setFeatStrideFPN(feat_stride_fpn)
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for SCRFDParser. Skipping setup..."
+        )
+
+
+def setup_segmentation_parser(parser: SegmentationParser, params: dict):
+    """Setup the segmentation parser with the required metadata."""
+    try:
+        background_class = params["background_class"]
+        parser.setBackgroundClass(background_class)
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for SegmentationParser. Skipping setup..."
+        )
+
+
+def setup_keypoint_parser(parser: KeypointParser, params: dict):
+    """Setup the keypoint parser with the required metadata."""
+    try:
+        num_keypoints = params["n_keypoints"]
+        scale_factor = params["scale_factor"]
+        parser.setNumKeypoints(num_keypoints)
+        parser.setScaleFactor(scale_factor)
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for KeypointParser. Skipping setup..."
+        )
+
+
+def setup_classification_parser(parser: ClassificationParser, params: dict):
+    """Setup the classification parser with the required metadata."""
+    try:
+        classes = params["classes"]
+        is_softmax = params["is_softmax"]
+        parser.setClasses(classes)
+        parser.setSoftmax(is_softmax)
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for ClassificationParser. Skipping setup..."
+        )
+
+
+def setup_monocular_depth_parser(parser: MonocularDepthParser, params: dict):
+    """Setup the monocular depth parser with the required metadata."""
+    try:
+        depth_type = params["depth_type"]
+        if depth_type == "relative":
+            parser.setRelativeDepthType()
+        else:
+            parser.setMetricDepthType()
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for MonocularDepthParser. Skipping setup..."
+        )
+
+
+def setup_xfeat_parser(parser: XFeatParser, params: dict):
+    """Setup the XFeat parser with the required metadata."""
+    try:
+        input_size = params["input_size"]
+        parser.setInputSize(input_size)
+        parser.setOriginalSize(input_size)
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for XFeatParser. Skipping setup..."
+        )
+
+
+def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
+    """Setup the parser with the NN archive."""
+
+    extraParams = (
+        nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
+    )
+
+    if parser_name == "SCRFDParser":
+        setup_scrfd_parser(parser, extraParams)
+    elif parser_name == "SegmentationParser":
+        setup_segmentation_parser(parser, extraParams)
+    elif parser_name == "KeypointParser":
+        setup_keypoint_parser(parser, extraParams)
+    elif parser_name == "ClassificationParser":
+        setup_classification_parser(parser, extraParams)
+    elif parser_name == "MonocularDepthParser":
+        setup_monocular_depth_parser(parser, extraParams)
+    elif parser_name == "XFeatParser":
+        setup_xfeat_parser(parser, extraParams)
diff --git a/examples/visualization/__init__.py b/examples/visualization/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/visualization/classification.py b/examples/visualization/classification.py
new file mode 100644
index 00000000..90c3bb04
--- /dev/null
+++ b/examples/visualization/classification.py
@@ -0,0 +1,57 @@
+import cv2
+import depthai as dai
+
+from depthai_nodes.ml.messages import AgeGender, Classifications
+
+from .messages import parse_classification_message, parser_age_gender_message
+
+
+def visualize_classification(
+    frame: dai.ImgFrame, message: Classifications, extraParams: dict
+):
+    """Visualizes the classification on the frame."""
+    classes, scores = parse_classification_message(message)
+    classes = classes[:2]
+    scores = scores[:2]
+    if frame.shape[0] < 128:
+        frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2))
+    for i, (cls, score) in enumerate(zip(classes, scores)):
+        cv2.putText(
+            frame,
+            f"{cls}: {score:.2f}",
+            (10, 20 + 20 * i),
+            cv2.FONT_HERSHEY_TRIPLEX,
+            0.5,
+            255,
+        )
+
+    cv2.imshow("Classification", frame)
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
+
+
+def visualize_age_gender(frame: dai.ImgFrame, message: AgeGender, extraParams: dict):
+    """Visualizes the age and predicted gender on the frame."""
+    if frame.shape[0] < 128:
+        frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2))
+    age, gender_classes, gender_scores = parser_age_gender_message(message)
+    cv2.putText(frame, f"Age: {age}", (10, 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
+    for i, (cls, score) in enumerate(zip(gender_classes, gender_scores)):
+        cv2.putText(
+            frame,
+            f"{cls}: {score:.2f}",
+            (10, 40 + 20 * i),
+            cv2.FONT_HERSHEY_TRIPLEX,
+            0.5,
+            255,
+        )
+
+    cv2.imshow("Age-gender", frame)
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
diff --git a/examples/visualization/colors.py b/examples/visualization/colors.py
new file mode 100644
index 00000000..995ada6c
--- /dev/null
+++ b/examples/visualization/colors.py
@@ -0,0 +1,39 @@
+def get_adas_colors():
+    colors = [
+        (0, 0, 0),  # class 0 - black
+        (128, 0, 0),  # class 1 - maroon
+        (0, 128, 0),  # class 2 - green
+        (128, 128, 0),  # class 3 - olive
+        (0, 0, 128),  # class 4 - navy
+        (128, 0, 128),  # class 5 - purple
+        (0, 128, 128),  # class 6 - teal
+        (128, 128, 128),  # class 7 - gray
+        (64, 0, 0),  # class 8 - maroon
+        (192, 0, 0),  # class 9 - red
+        (64, 128, 0),  # class 10 - olive
+        (192, 128, 0),  # class 11 - yellow
+        (64, 0, 128),  # class 12 - navy
+        (192, 0, 128),  # class 13 - fuchsia
+        (64, 128, 128),  # class 14 - aqua
+        (192, 128, 128),  # class 15 - silver
+        (0, 64, 0),  # class 16 - green
+        (128, 64, 0),  # class 17 - orange
+        (0, 192, 0),  # class 18 - lime
+        (128, 192, 0),  # class 19 - yellow
+        (0, 64, 128),  # class 20 - blue
+    ]
+    return colors
+
+
+def get_selfie_colors():
+    colors = [(0, 0, 0), (0, 255, 0)]
+    return colors
+
+
+def get_ewasr_colors():
+    colors = [
+        (0, 255, 255),  # class 0 - cyan
+        (255, 255, 0),  # class 1 - yellow
+        (0, 0, 0),  # class 2 - black
+    ]
+    return colors
diff --git a/examples/visualization/detection.py b/examples/visualization/detection.py
new file mode 100644
index 00000000..a3e5c534
--- /dev/null
+++ b/examples/visualization/detection.py
@@ -0,0 +1,102 @@
+import cv2
+import depthai as dai
+
+from depthai_nodes.ml.messages import Lines
+
+from .messages import parse_detection_message, parse_line_detection_message
+
+
+def visualize_detections(
+    frame: dai.ImgFrame, message: dai.ImgDetections, extraParams: dict
+):
+    """Visualizes the detections on the frame.
+
+    Also, checks if there are any keypoints available to visualize.
+    """
+    labels = extraParams.get("classes", None)
+    detections = parse_detection_message(message)
+    for detection in detections:
+        xmin, ymin, xmax, ymax = (
+            detection.xmin,
+            detection.ymin,
+            detection.xmax,
+            detection.ymax,
+        )
+        if xmin > 1 or ymin > 1 or xmax > 1 or ymax > 1:
+            xmin = int(xmin)
+            ymin = int(ymin)
+            xmax = int(xmax)
+            ymax = int(ymax)
+        else:
+            xmin = int(xmin * frame.shape[1])
+            ymin = int(ymin * frame.shape[0])
+            xmax = int(xmax * frame.shape[1])
+            ymax = int(ymax * frame.shape[0])
+        cv2.rectangle(
+            frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 2
+        )
+
+        try:
+            keypoints = detection.keypoints
+            for kp in keypoints:
+                cv2.circle(
+                    frame,
+                    (int(kp[0] * frame.shape[1]), int(kp[1] * frame.shape[0])),
+                    5,
+                    (0, 0, 255),
+                    -1,
+                )
+        except Exception:
+            print("No keypoints available.")
+
+        cv2.putText(
+            frame,
+            f"{detection.confidence * 100:.2f}%",
+            (int(xmin) + 10, int(ymin) + 20),
+            cv2.FONT_HERSHEY_TRIPLEX,
+            0.5,
+            255,
+        )
+        if labels is not None:
+            cv2.putText(
+                frame,
+                labels[detection.label],
+                (int(xmin) + 10, int(ymin) + 40),
+                cv2.FONT_HERSHEY_TRIPLEX,
+                0.5,
+                255,
+            )
+
+    cv2.imshow("Detections", frame)
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
+
+
+def visualize_line_detections(frame: dai.ImgFrame, message: Lines, extraParams: dict):
+    """Visualizes the lines on the frame."""
+    lines = parse_line_detection_message(message)
+    h, w = frame.shape[:2]
+    for line in lines:
+        x1 = line.start_point.x * w
+        y1 = line.start_point.y * h
+        x2 = line.end_point.x * w
+        y2 = line.end_point.y * h
+        cv2.line(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 3, 16)
+
+    cv2.putText(
+        frame,
+        f"Number of lines: {len(lines)}",
+        (2, frame.shape[0] - 4),
+        cv2.FONT_HERSHEY_COMPLEX,
+        0.5,
+        (255, 0, 0),
+    )
+    cv2.imshow("Lines", frame)
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
diff --git a/examples/visualization/image.py b/examples/visualization/image.py
new file mode 100644
index 00000000..19255cbb
--- /dev/null
+++ b/examples/visualization/image.py
@@ -0,0 +1,15 @@
+import cv2
+import depthai as dai
+
+from .messages import parse_image_message
+
+
+def visualize_image(frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict):
+    """Visualizes the image on the frame."""
+    image = parse_image_message(message)
+    cv2.imshow("Image", image)
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
diff --git a/examples/visualization/keypoints.py b/examples/visualization/keypoints.py
new file mode 100644
index 00000000..a587c2e2
--- /dev/null
+++ b/examples/visualization/keypoints.py
@@ -0,0 +1,23 @@
+import cv2
+import depthai as dai
+
+from depthai_nodes.ml.messages import Keypoints
+
+from .messages import parse_keypoints_message
+
+
+def visualize_keypoints(frame: dai.ImgFrame, message: Keypoints, extraParams: dict):
+    """Visualizes the keypoints on the frame."""
+    keypoints = parse_keypoints_message(message)
+
+    for kp in keypoints:
+        x = int(kp.x * frame.shape[1])
+        y = int(kp.y * frame.shape[0])
+        cv2.circle(frame, (x, y), 1, (0, 255, 0), -1)
+
+    cv2.imshow("Keypoints", frame)
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
diff --git a/examples/visualization/mapping.py b/examples/visualization/mapping.py
new file mode 100644
index 00000000..9aa69d5e
--- /dev/null
+++ b/examples/visualization/mapping.py
@@ -0,0 +1,23 @@
+from .classification import visualize_age_gender, visualize_classification
+from .detection import visualize_detections, visualize_line_detections
+from .image import visualize_image
+from .keypoints import visualize_keypoints
+from .segmentation import visualize_segmentation
+
+parser_mapping = {
+    "YuNetParser": visualize_detections,
+    "SCRFDParser": visualize_detections,
+    "MPPalmDetectionParser": visualize_detections,
+    "YOLO": visualize_detections,
+    "SSD": visualize_detections,
+    "SegmentationParser": visualize_segmentation,
+    "MLSDParser": visualize_line_detections,
+    "KeypointParser": visualize_keypoints,
+    "HRNetParser": visualize_keypoints,
+    "SuperAnimalParser": visualize_keypoints,
+    "MPHandLandmarkParser": visualize_keypoints,
+    "ClassificationParser": visualize_classification,
+    "ImageOutputParser": visualize_image,
+    "MonocularDepthParser": visualize_image,
+    "AgeGenderParser": visualize_age_gender,
+}
diff --git a/examples/visualization/messages.py b/examples/visualization/messages.py
new file mode 100644
index 00000000..a74aa8df
--- /dev/null
+++ b/examples/visualization/messages.py
@@ -0,0 +1,52 @@
+import depthai as dai
+
+from depthai_nodes.ml.messages import AgeGender, Classifications, Keypoints, Lines
+
+
+def parse_detection_message(message: dai.ImgDetections):
+    """Parses the detection message and returns the detections."""
+    detections = message.detections
+    return detections
+
+
+def parse_line_detection_message(message: Lines):
+    """Parses the line detection message and returns the lines."""
+    lines = message.lines
+    return lines
+
+
+def parse_segmentation_message(message: dai.ImgFrame):
+    """Parses the segmentation message and returns the mask."""
+    mask = message.getFrame()
+    mask = mask.reshape(mask.shape[0], mask.shape[1])
+    return mask
+
+
+def parse_keypoints_message(message: Keypoints):
+    """Parses the keypoints message and returns the keypoints."""
+    keypoints = message.keypoints
+    return keypoints
+
+
+def parse_classification_message(message: Classifications):
+    """Parses the classification message and returns the classification."""
+    classes = message.classes
+    scores = message.scores
+    return classes, scores
+
+
+def parse_image_message(message: dai.ImgFrame):
+    """Parses the image message and returns the image."""
+    image = message.getFrame()
+    return image
+
+
+def parser_age_gender_message(message: AgeGender):
+    """Parses the age-gender message and returns the age and scores for all genders."""
+
+    age = message.age
+    gender = message.gender
+    gender_scores = gender.scores
+    gender_classes = gender.classes
+
+    return age, gender_classes, gender_scores
diff --git a/examples/visualization/segmentation.py b/examples/visualization/segmentation.py
new file mode 100644
index 00000000..0e33eb67
--- /dev/null
+++ b/examples/visualization/segmentation.py
@@ -0,0 +1,41 @@
+import cv2
+import depthai as dai
+import numpy as np
+
+from .colors import get_adas_colors, get_ewasr_colors, get_selfie_colors
+from .messages import parse_segmentation_message
+
+
+def visualize_segmentation(
+    frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict
+):
+    mask = parse_segmentation_message(message)
+    frame = cv2.resize(frame, (mask.shape[1], mask.shape[0]))
+
+    n_classes = extraParams.get("n_classes", None)
+
+    if n_classes is None:
+        raise ValueError("Number of classes not provided in NN archive metadata.")
+
+    if n_classes == 2:
+        COLORS = get_selfie_colors()
+    elif n_classes == 3:
+        COLORS = get_ewasr_colors()
+    else:
+        COLORS = get_adas_colors()
+
+    colormap = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+    for class_id, color in enumerate(COLORS):
+        m = mask == class_id
+        colormap[m] = color
+
+    alpha = 0.5
+    overlay = cv2.addWeighted(colormap, alpha, frame, 1 - alpha, 0)
+
+    cv2.imshow("Segmentation", overlay)
+
+    if cv2.waitKey(1) == ord("q"):
+        cv2.destroyAllWindows()
+        return True
+
+    return False
diff --git a/examples/visualization/visualize.py b/examples/visualization/visualize.py
new file mode 100644
index 00000000..c4ce279b
--- /dev/null
+++ b/examples/visualization/visualize.py
@@ -0,0 +1,12 @@
+import depthai as dai
+
+from .mapping import parser_mapping
+
+
+def visualize(
+    frame: dai.ImgFrame, message: dai.Buffer, parser_name: str, extraParams: dict
+):
+    """Calls the appropriate visualizer based on the parser name and returns True if the
+    pipeline should be stopped."""
+    visualizer = parser_mapping[parser_name]
+    return visualizer(frame, message, extraParams)
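
Usage note: the example is driven entirely by a HubAI model slug passed as -s/--model_slug in the "model_slug:model_version_slug" format, e.g. "python main.py -s <model_slug>:<model_version_slug>" run from the examples directory (the angle-bracket values are placeholders, not real slugs). The helpers added in examples/utils/model.py can also be composed on their own; below is a minimal sketch under the same assumptions the patch makes (HubAI access for the download; a connected RVC2 device is only needed for the full pipeline in main.py):

    # Minimal sketch using the helpers from examples/utils/model.py.
    # "some-model" / "some-version" are placeholder slugs, not a real HubAI model.
    from utils.model import get_input_shape, get_model_from_hub, get_parser

    nn_archive = get_model_from_hub("some-model", "some-version")  # downloads and opens the NN archive
    parser_class, parser_name = get_parser(nn_archive)  # resolves the depthai-nodes parser named in the archive
    print(parser_name, get_input_shape(nn_archive))  # parser name and the (width, height) input size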