diff --git a/examples/utils/parser.py b/examples/utils/parser.py
index 939f45d7..cda5a372 100644
--- a/examples/utils/parser.py
+++ b/examples/utils/parser.py
@@ -7,6 +7,7 @@
     SCRFDParser,
     SegmentationParser,
     XFeatParser,
+    YOLOExtendedParser,
 )
 
 
@@ -86,6 +87,17 @@ def setup_xfeat_parser(parser: XFeatParser, params: dict):
     )
 
 
+def setup_yolo_extended_parser(parser: YOLOExtendedParser, params: dict):
+    """Setup the YOLO parser with the required metadata."""
+    try:
+        n_classes = params["n_classes"]
+        parser.setNumClasses(n_classes)
+    except Exception:
+        print(
+            "This NN archive does not have required metadata for YOLOExtendedParser. Skipping setup..."
+        )
+
+
 def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
     """Setup the parser with the NN archive."""
 
@@ -105,3 +117,5 @@ def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_nam
         setup_monocular_depth_parser(parser, extraParams)
     elif parser_name == "XFeatParser":
         setup_xfeat_parser(parser, extraParams)
+    elif parser_name == "YOLOExtendedParser":
+        setup_yolo_extended_parser(parser, extraParams)
diff --git a/examples/visualization/colors.py b/examples/visualization/colors.py
index 995ada6c..a41c6bdd 100644
--- a/examples/visualization/colors.py
+++ b/examples/visualization/colors.py
@@ -37,3 +37,30 @@ def get_ewasr_colors():
         (0, 0, 0),  # class 2 - green
     ]
     return colors
+
+
+def get_yolo_colors():
+    colors = [
+        [255, 128, 0],
+        [255, 153, 51],
+        [255, 178, 102],
+        [230, 230, 0],
+        [255, 153, 255],
+        [153, 204, 255],
+        [255, 102, 255],
+        [255, 51, 255],
+        [102, 178, 255],
+        [51, 153, 255],
+        [255, 153, 153],
+        [255, 102, 102],
+        [255, 51, 51],
+        [153, 255, 153],
+        [102, 255, 102],
+        [51, 255, 51],
+        [0, 255, 0],
+        [0, 0, 255],
+        [255, 0, 0],
+        [255, 255, 255],
+    ]
+
+    return colors
diff --git a/examples/visualization/detection.py b/examples/visualization/detection.py
index 50f63315..bfc1c6c7 100644
--- a/examples/visualization/detection.py
+++ b/examples/visualization/detection.py
@@ -1,9 +1,15 @@
 import cv2
 import depthai as dai
+import numpy as np
 
-from depthai_nodes.ml.messages import Lines
+from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines
 
-from .messages import parse_detection_message, parse_line_detection_message
+from .colors import get_yolo_colors
+from .messages import (
+    parse_detection_message,
+    parse_line_detection_message,
+    parse_yolo_kpts_message,
+)
 
 
 def visualize_detections(
@@ -100,3 +106,72 @@ def visualize_line_detections(frame: dai.ImgFrame, message: Lines, extraParams:
         return True
 
     return False
+
+
+def visualize_yolo_extended(
+    frame: dai.ImgFrame, message: ImgDetectionsExtended, extraParams: dict
+):
+    """Visualizes the YOLO pose detections or instance segmentation on the frame."""
+    detections = parse_yolo_kpts_message(message)
+
+    colors = get_yolo_colors()
+    classes = extraParams.get("classes", None)
+    if classes is None:
+        raise ValueError("Classes are required for visualization.")
+    task = extraParams.get("n_keypoints", None)
+    if task is None:
+        task = "segmentation"
+    else:
+        task = "keypoints"
+
+    overlay = np.zeros_like(frame)
+
+    for detection in detections:
+        xmin, ymin, xmax, ymax = (
+            detection.xmin,
+            detection.ymin,
+            detection.xmax,
+            detection.ymax,
+        )
+        cv2.rectangle(
+            frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 2
+        )
+        cv2.putText(
+            frame,
+            f"{detection.confidence * 100:.2f}%",
+            (int(xmin) + 10, int(ymin) + 20),
+            cv2.FONT_HERSHEY_TRIPLEX,
+            0.5,
+            255,
+        )
+        cv2.putText(
+            frame,
f"{classes[detection.label]}", + (int(xmin) + 10, int(ymin) + 40), + cv2.FONT_HERSHEY_TRIPLEX, + 0.5, + 255, + ) + + if task == "keypoints": + keypoints = detection.keypoints + for keypoint in keypoints: + x, y, visibility = keypoint[0], keypoint[1], keypoint[2] + if visibility > 0.8: + cv2.circle(frame, (int(x), int(y)), 2, (0, 255, 0), -1) + else: + mask = detection.mask + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0])) + mask = detection.mask + mask = cv2.resize(mask, (512, 288)) + label = detection.label + overlay[mask > 0] = colors[(label % len(colors))] + + if task == "segmentation": + frame = cv2.addWeighted(overlay, 0.8, frame, 0.5, 0, None) + cv2.imshow("YOLO Pose Estimation", frame) + if cv2.waitKey(1) == ord("q"): + cv2.destroyAllWindows() + return True + + return False diff --git a/examples/visualization/mapping.py b/examples/visualization/mapping.py index 9aa69d5e..7891da97 100644 --- a/examples/visualization/mapping.py +++ b/examples/visualization/mapping.py @@ -1,5 +1,9 @@ from .classification import visualize_age_gender, visualize_classification -from .detection import visualize_detections, visualize_line_detections +from .detection import ( + visualize_detections, + visualize_line_detections, + visualize_yolo_extended, +) from .image import visualize_image from .keypoints import visualize_keypoints from .segmentation import visualize_segmentation @@ -20,4 +24,5 @@ "ImageOutputParser": visualize_image, "MonocularDepthParser": visualize_image, "AgeGenderParser": visualize_age_gender, + "YOLOExtendedParser": visualize_yolo_extended, } diff --git a/examples/visualization/messages.py b/examples/visualization/messages.py index a74aa8df..f735a778 100644 --- a/examples/visualization/messages.py +++ b/examples/visualization/messages.py @@ -1,6 +1,12 @@ import depthai as dai -from depthai_nodes.ml.messages import AgeGender, Classifications, Keypoints, Lines +from depthai_nodes.ml.messages import ( + AgeGender, + Classifications, + ImgDetectionsExtended, + Keypoints, + Lines, +) def parse_detection_message(message: dai.ImgDetections): @@ -50,3 +56,9 @@ def parser_age_gender_message(message: AgeGender): gender_classes = gender.classes return age, gender_classes, gender_scores + + +def parse_yolo_kpts_message(message: ImgDetectionsExtended): + """Parses the yolo keypoints message and returns the keypoints.""" + detections = message.detections + return detections