From 38be7e83f4d560ca753acd8c66c0cc03ce1e7909 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Tue, 28 May 2024 09:30:59 +0200 Subject: [PATCH 01/29] New import path. --- ml/postprocessing/scrfd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/postprocessing/scrfd.py b/ml/postprocessing/scrfd.py index 939b1dab..379ca063 100644 --- a/ml/postprocessing/scrfd.py +++ b/ml/postprocessing/scrfd.py @@ -2,7 +2,7 @@ import numpy as np import cv2 -from ..custom_messages.img_detections import ImgDetectionsWithKeypoints +from ..messages import ImgDetectionsWithKeypoints class SCRFDParser(dai.node.ThreadedHostNode): def __init__( From c7a566ffb6be4f87f116c5b74b710770426dcbce Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Tue, 28 May 2024 09:32:02 +0200 Subject: [PATCH 02/29] Segmentation msg creation and new selfi seg. output. --- ml/postprocessing/selfie_seg.py | 17 ++++------------ .../message_creation/depth_segmentation.py | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/ml/postprocessing/selfie_seg.py b/ml/postprocessing/selfie_seg.py index 4c9979e1..24eb620e 100644 --- a/ml/postprocessing/selfie_seg.py +++ b/ml/postprocessing/selfie_seg.py @@ -1,13 +1,13 @@ import depthai as dai import numpy as np import cv2 +from .utils.message_creation.depth_segmentation import create_segmentation_msg class SeflieSegParser(dai.node.ThreadedHostNode): def __init__( self, threshold=0.5, input_size=(256, 144), - mask_color=[0, 255, 0], ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) @@ -15,10 +15,6 @@ def __init__( self.input_size = input_size self.threshold = threshold - self.mask_color = mask_color - - def setMaskColor(self, mask_color): - self.mask_color = mask_color def setConfidenceThreshold(self, threshold): self.threshold = threshold @@ -46,13 +42,8 @@ def run(self): segmentation_mask = output.getTensor("output") segmentation_mask = segmentation_mask[0].squeeze() > self.threshold - overlay_image = np.ones((segmentation_mask.shape[0], segmentation_mask.shape[1], 3), dtype=np.uint8) * 255 - overlay_image[segmentation_mask] = self.mask_color - - imgFrame = dai.ImgFrame() - imgFrame.setFrame(overlay_image) - imgFrame.setWidth(overlay_image.shape[1]) - imgFrame.setHeight(overlay_image.shape[0]) - imgFrame.setType(dai.ImgFrame.Type.BGR888i) + overlay_image = np.zeros((segmentation_mask.shape[0], segmentation_mask.shape[1], 1), dtype=np.uint8) + overlay_image[segmentation_mask] = 1 + imgFrame = create_segmentation_msg(overlay_image) self.out.send(imgFrame) \ No newline at end of file diff --git a/ml/postprocessing/utils/message_creation/depth_segmentation.py b/ml/postprocessing/utils/message_creation/depth_segmentation.py index e69de29b..8c6c041c 100644 --- a/ml/postprocessing/utils/message_creation/depth_segmentation.py +++ b/ml/postprocessing/utils/message_creation/depth_segmentation.py @@ -0,0 +1,20 @@ +import depthai as dai +import numpy as np + +def create_segmentation_msg(mask: np.array) -> dai.ImgFrame: + """ + Create a message for the segmentation mask. Mask is of the shape (H, W, 1). In the third dimesion we specify the class. + + Args: + mask (np.array): The segmentation mask. + + Returns: + dai.ImgFrame: The message containing the segmentation mask. 
+ """ + imgFrame = dai.ImgFrame() + imgFrame.setFrame(mask) + imgFrame.setWidth(mask.shape[1]) + imgFrame.setHeight(mask.shape[0]) + imgFrame.setType(dai.ImgFrame.Type.GRAY8) + + return imgFrame \ No newline at end of file From e7dc367144eae2ad6d908966fe688dc904ed52df Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Tue, 28 May 2024 09:32:24 +0200 Subject: [PATCH 03/29] Nodes for MP hands. --- ml/messages/__init__.py | 3 +- ml/messages/landmarks.py | 9 + ml/postprocessing/mp_hand_landmark.py | 76 +++++ ml/postprocessing/mp_palm_detection.py | 98 ++++++ ml/postprocessing/utils/medipipe_utils.py | 390 ++++++++++++++++++++++ 5 files changed, 575 insertions(+), 1 deletion(-) create mode 100644 ml/messages/landmarks.py create mode 100644 ml/postprocessing/mp_hand_landmark.py create mode 100644 ml/postprocessing/mp_palm_detection.py create mode 100644 ml/postprocessing/utils/medipipe_utils.py diff --git a/ml/messages/__init__.py b/ml/messages/__init__.py index 795161f2..6546ea66 100644 --- a/ml/messages/__init__.py +++ b/ml/messages/__init__.py @@ -1 +1,2 @@ -from .img_detections import ImgDetectionsWithKeypoints \ No newline at end of file +from .img_detections import ImgDetectionsWithKeypoints +from .landmarks import HandLandmarks \ No newline at end of file diff --git a/ml/messages/landmarks.py b/ml/messages/landmarks.py new file mode 100644 index 00000000..72dec25d --- /dev/null +++ b/ml/messages/landmarks.py @@ -0,0 +1,9 @@ +import depthai as dai +from typing import List + +class HandLandmarks(dai.Buffer): + def __init__(self): + dai.Buffer.__init__(self) + self.confidence: float = 0.0 + self.handness: float = 0.0 + self.landmarks: List[dai.Point3f] = [] \ No newline at end of file diff --git a/ml/postprocessing/mp_hand_landmark.py b/ml/postprocessing/mp_hand_landmark.py new file mode 100644 index 00000000..476c71da --- /dev/null +++ b/ml/postprocessing/mp_hand_landmark.py @@ -0,0 +1,76 @@ +import depthai as dai +import numpy as np +import cv2 + +from ..messages import HandLandmarks + +class MPHandLandmarkParser(dai.node.ThreadedHostNode): + def __init__( + self, + score_threshold=0.5, + handdedness_threshold=0.5, + input_size=(224, 224) + ): + dai.node.ThreadedHostNode.__init__(self) + self.input = dai.Node.Input(self) + self.out = dai.Node.Output(self) + + self.score_threshold = score_threshold + self.input_size = input_size + self.handdedness_threshold = handdedness_threshold + + def setScoreThreshold(self, threshold): + self.score_threshold = threshold + + def setHandednessThreshold(self, threshold): + self.handdedness_threshold = threshold + + def setInputSize(self, width, height): + self.input_size = (width, height) + + def run(self): + """ + Postprocessing logic for SCRFD model. + + Returns: + ... 
+ """ + + while self.isRunning(): + + try: + output: dai.NNData = self.input.get() + except dai.MessageQueue.QueueException as e: + break # Pipeline was stopped + + print('MP Hand landmark node') + print(f"Layer names = {output.getAllLayerNames()}") + + tensorInfo = output.getTensorInfo("Identity") + landmarks = output.getTensor(f"Identity").reshape(21, 3).astype(np.float32) + landmarks = (landmarks - tensorInfo.qpZp) * tensorInfo.qpScale + tensorInfo = output.getTensorInfo("Identity_1") + hand_score = output.getTensor(f"Identity_1").reshape(-1).astype(np.float32) + hand_score = (hand_score - tensorInfo.qpZp) * tensorInfo.qpScale + hand_score = hand_score[0] + tensorInfo = output.getTensorInfo("Identity_2") + handdedness = output.getTensor(f"Identity_2").reshape(-1).astype(np.float32) + handdedness = (handdedness - tensorInfo.qpZp) * tensorInfo.qpScale + handdedness = handdedness[0] + + hand_landmarks_msg = HandLandmarks() + if hand_score < self.score_threshold: + hand_landmarks_msg.landmarks = [] + hand_landmarks_msg.confidence = hand_score + hand_landmarks_msg.handedness = handdedness + self.out.send(hand_landmarks_msg) + else: + hand_landmarks_msg.confidence = hand_score + hand_landmarks_msg.handedness = handdedness + for i in range(21): + pt = dai.Point3f() + pt.x = landmarks[i][0] + pt.y = landmarks[i][1] + pt.z = landmarks[i][2] + hand_landmarks_msg.landmarks.append(pt) + self.out.send(hand_landmarks_msg) \ No newline at end of file diff --git a/ml/postprocessing/mp_palm_detection.py b/ml/postprocessing/mp_palm_detection.py new file mode 100644 index 00000000..f09c6c09 --- /dev/null +++ b/ml/postprocessing/mp_palm_detection.py @@ -0,0 +1,98 @@ +import depthai as dai +import numpy as np +import cv2 + +from ..messages import ImgDetectionsWithKeypoints +from .utils.medipipe_utils import generate_handtracker_anchors, decode_bboxes, rect_transformation, detections_to_rect + +class MPPalmDetectionParser(dai.node.ThreadedHostNode): + def __init__( + self, + score_threshold=0.5, + nms_threshold=0.5, + top_k=100, + input_size=(192, 192), # WH + ): + dai.node.ThreadedHostNode.__init__(self) + self.input = dai.Node.Input(self) + self.out = dai.Node.Output(self) + + self.input_size = input_size + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.top_k = top_k + + def setConfidenceThreshold(self, threshold): + self.score_threshold = threshold + + def setNMSThreshold(self, threshold): + self.nms_threshold = threshold + + def setTopK(self, top_k): + self.top_k = top_k + + def setInputSize(self, width, height): + self.input_size = (width, height) + + def run(self): + """ + Postprocessing logic for SCRFD model. + + Returns: + ... 
+ """ + + while self.isRunning(): + + try: + output: dai.NNData = self.input.get() + except dai.MessageQueue.QueueException as e: + break # Pipeline was stopped + + print('MP Palm detection node') + print(f"Layer names = {output.getAllLayerNames()}") + + tensorInfo = output.getTensorInfo("Identity") + bboxes = output.getTensor(f"Identity").reshape(2016, 18).astype(np.float32) + bboxes = (bboxes - tensorInfo.qpZp) * tensorInfo.qpScale + tensorInfo = output.getTensorInfo("Identity_1") + scores = output.getTensor(f"Identity_1").reshape(2016).astype(np.float32) + scores = (scores - tensorInfo.qpZp) * tensorInfo.qpScale + + anchors = generate_handtracker_anchors(192, 192) + decoded_bboxes = decode_bboxes(0.5, scores, bboxes, anchors, scale=192) + detections_to_rect(decoded_bboxes) + rect_transformation(decoded_bboxes, 192, 192) + + bboxes = [] + scores = [] + + for hand in decoded_bboxes: + extended_points = hand.rect_points + xmin = int(min(extended_points[0][0], extended_points[1][0])) + ymin = int(min(extended_points[0][1], extended_points[1][1])) + xmax = int(max(extended_points[2][0], extended_points[3][0])) + ymax = int(max(extended_points[2][1], extended_points[3][1])) + + bboxes.append([xmin, ymin, xmax, ymax]) + scores.append(hand.pd_score) + + indices = cv2.dnn.NMSBoxes(bboxes, scores, self.score_threshold, self.nms_threshold, top_k=self.top_k) + bboxes = np.array(bboxes)[indices] + scores = np.array(scores)[indices] + + detections = [] + for bbox, score in zip(bboxes, scores): + detection = dai.ImgDetection() + detection.confidence = score + detection.label = 0 + detection.xmin = bbox[0] + detection.ymin = bbox[1] + detection.xmax = bbox[2] + detection.ymax = bbox[3] + detections.append(detection) + + detections_msg = dai.ImgDetections() + detections_msg.detections = detections + + self.out.send(detections_msg) \ No newline at end of file diff --git a/ml/postprocessing/utils/medipipe_utils.py b/ml/postprocessing/utils/medipipe_utils.py new file mode 100644 index 00000000..94fe3c9a --- /dev/null +++ b/ml/postprocessing/utils/medipipe_utils.py @@ -0,0 +1,390 @@ +import math +import numpy as np +from collections import namedtuple + +class HandRegion: + """ + Attributes: + pd_score : detection score + pd_box : detection box [x, y, w, h], normalized [0,1] in the squared image + pd_kps : detection keypoints coordinates [x, y], normalized [0,1] in the squared image + rect_x_center, rect_y_center : center coordinates of the rotated bounding rectangle, normalized [0,1] in the squared image + rect_w, rect_h : width and height of the rotated bounding rectangle, normalized in the squared image (may be > 1) + rotation : rotation angle of rotated bounding rectangle with y-axis in radian + rect_x_center_a, rect_y_center_a : center coordinates of the rotated bounding rectangle, in pixels in the squared image + rect_w, rect_h : width and height of the rotated bounding rectangle, in pixels in the squared image + rect_points : list of the 4 points coordinates of the rotated bounding rectangle, in pixels + expressed in the squared image during processing, + expressed in the source rectangular image when returned to the user + lm_score: global landmark score + norm_landmarks : 3D landmarks coordinates in the rotated bounding rectangle, normalized [0,1] + landmarks : 2D landmark coordinates in pixel in the source rectangular image + world_landmarks : 3D landmark coordinates in meter + handedness: float between 0. 
and 1., > 0.5 for right hand, < 0.5 for left hand, + label: "left" or "right", handedness translated in a string, + xyz: real 3D world coordinates of the wrist landmark, or of the palm center (if landmarks are not used), + xyz_zone: (left, top, right, bottom), pixel coordinates in the source rectangular image + of the rectangular zone used to estimate the depth + gesture: (optional, set in recognize_gesture() when use_gesture==True) string corresponding to recognized gesture ("ONE","TWO","THREE","FOUR","FIVE","FIST","OK","PEACE") + or None if no gesture has been recognized + """ + def __init__(self, pd_score=None, pd_box=None, pd_kps=None): + self.pd_score = pd_score # Palm detection score + self.pd_box = pd_box # Palm detection box [x, y, w, h] normalized + self.pd_kps = pd_kps # Palm detection keypoints + + def get_rotated_world_landmarks(self): + world_landmarks_rotated = self.world_landmarks.copy() + sin_rot = math.sin(self.rotation) + cos_rot = math.cos(self.rotation) + rot_m = np.array([[cos_rot, sin_rot], [-sin_rot, cos_rot]]) + world_landmarks_rotated[:,:2] = np.dot(world_landmarks_rotated[:,:2], rot_m) + return world_landmarks_rotated + + def print(self): + attrs = vars(self) + print('\n'.join("%s: %s" % item for item in attrs.items())) + +SSDAnchorOptions = namedtuple('SSDAnchorOptions',[ + 'num_layers', + 'min_scale', + 'max_scale', + 'input_size_height', + 'input_size_width', + 'anchor_offset_x', + 'anchor_offset_y', + 'strides', + 'aspect_ratios', + 'reduce_boxes_in_lowest_layer', + 'interpolated_scale_aspect_ratio', + 'fixed_anchor_size']) + +def calculate_scale(min_scale, max_scale, stride_index, num_strides): + if num_strides == 1: + return (min_scale + max_scale) / 2 + else: + return min_scale + (max_scale - min_scale) * stride_index / (num_strides - 1) + +def generate_anchors(options): + """ + option : SSDAnchorOptions + # https://github.com/google/mediapipe/blob/master/mediapipe/calculators/tflite/ssd_anchors_calculator.cc + """ + anchors = [] + layer_id = 0 + n_strides = len(options.strides) + while layer_id < n_strides: + anchor_height = [] + anchor_width = [] + aspect_ratios = [] + scales = [] + # For same strides, we merge the anchors in the same order. + last_same_stride_layer = layer_id + while last_same_stride_layer < n_strides and \ + options.strides[last_same_stride_layer] == options.strides[layer_id]: + scale = calculate_scale(options.min_scale, options.max_scale, last_same_stride_layer, n_strides) + if last_same_stride_layer == 0 and options.reduce_boxes_in_lowest_layer: + # For first layer, it can be specified to use predefined anchors. 
+ aspect_ratios += [1.0, 2.0, 0.5] + scales += [0.1, scale, scale] + else: + aspect_ratios += options.aspect_ratios + scales += [scale] * len(options.aspect_ratios) + if options.interpolated_scale_aspect_ratio > 0: + if last_same_stride_layer == n_strides -1: + scale_next = 1.0 + else: + scale_next = calculate_scale(options.min_scale, options.max_scale, last_same_stride_layer+1, n_strides) + scales.append(math.sqrt(scale * scale_next)) + aspect_ratios.append(options.interpolated_scale_aspect_ratio) + last_same_stride_layer += 1 + + for i,r in enumerate(aspect_ratios): + ratio_sqrts = math.sqrt(r) + anchor_height.append(scales[i] / ratio_sqrts) + anchor_width.append(scales[i] * ratio_sqrts) + + stride = options.strides[layer_id] + feature_map_height = math.ceil(options.input_size_height / stride) + feature_map_width = math.ceil(options.input_size_width / stride) + + for y in range(feature_map_height): + for x in range(feature_map_width): + for anchor_id in range(len(anchor_height)): + x_center = (x + options.anchor_offset_x) / feature_map_width + y_center = (y + options.anchor_offset_y) / feature_map_height + # new_anchor = Anchor(x_center=x_center, y_center=y_center) + if options.fixed_anchor_size: + new_anchor = [x_center, y_center, 1.0, 1.0] + # new_anchor.w = 1.0 + # new_anchor.h = 1.0 + else: + new_anchor = [x_center, y_center, anchor_width[anchor_id], anchor_height[anchor_id]] + # new_anchor.w = anchor_width[anchor_id] + # new_anchor.h = anchor_height[anchor_id] + anchors.append(new_anchor) + + layer_id = last_same_stride_layer + return np.array(anchors) + +def generate_handtracker_anchors(input_size_width, input_size_height): + # https://github.com/google/mediapipe/blob/master/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt + anchor_options = SSDAnchorOptions(num_layers=4, + min_scale=0.1484375, + max_scale=0.75, + input_size_height=input_size_height, + input_size_width=input_size_width, + anchor_offset_x=0.5, + anchor_offset_y=0.5, + strides=[8, 16, 16, 16], + aspect_ratios= [1.0], + reduce_boxes_in_lowest_layer=False, + interpolated_scale_aspect_ratio=1.0, + fixed_anchor_size=True) + return generate_anchors(anchor_options) + +def decode_bboxes(score_thresh, scores, bboxes, anchors, scale=128, best_only=False): + """ + wi, hi : NN input shape + mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc + # Decodes the detection tensors generated by the model, based on + # the SSD anchors and the specification in the options, into a vector of + # detections. Each detection describes a detected object. 
+ + https://github.com/google/mediapipe/blob/master/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt : + node { + calculator: "TensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:unfiltered_detections" + options: { + [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { + num_classes: 1 + num_boxes: 896 + num_coords: 18 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 7 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.5 + } + } + } + node { + calculator: "TensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:unfiltered_detections" + options: { + [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { + num_classes: 1 + num_boxes: 2016 + num_coords: 18 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 7 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + + x_scale: 192.0 + y_scale: 192.0 + w_scale: 192.0 + h_scale: 192.0 + min_score_thresh: 0.5 + } + } + } + + scores: shape = [number of anchors 896 or 2016] + bboxes: shape = [ number of anchors x 18], 18 = 4 (bounding box : (cx,cy,w,h) + 14 (7 palm keypoints) + """ + regions = [] + scores = 1 / (1 + np.exp(-scores)) + if best_only: + best_id = np.argmax(scores) + if scores[best_id] < score_thresh: return regions + det_scores = scores[best_id:best_id+1] + det_bboxes2 = bboxes[best_id:best_id+1] + det_anchors = anchors[best_id:best_id+1] + else: + detection_mask = scores > score_thresh + det_scores = scores[detection_mask] + if det_scores.size == 0: return regions + det_bboxes2 = bboxes[detection_mask] + det_anchors = anchors[detection_mask] + + # scale = 128 # x_scale, y_scale, w_scale, h_scale + # scale = 192 # x_scale, y_scale, w_scale, h_scale + + # cx, cy, w, h = bboxes[i,:4] + # cx = cx * anchor.w / wi + anchor.x_center + # cy = cy * anchor.h / hi + anchor.y_center + # lx = lx * anchor.w / wi + anchor.x_center + # ly = ly * anchor.h / hi + anchor.y_center + det_bboxes = det_bboxes2* np.tile(det_anchors[:,2:4], 9) / scale + np.tile(det_anchors[:,0:2],9) + # w = w * anchor.w / wi (in the prvious line, we add anchor.x_center and anchor.y_center to w and h, we need to substract them now) + # h = h * anchor.h / hi + det_bboxes[:,2:4] = det_bboxes[:,2:4] - det_anchors[:,0:2] + # box = [cx - w*0.5, cy - h*0.5, w, h] + det_bboxes[:,0:2] = det_bboxes[:,0:2] - det_bboxes[:,3:4] * 0.5 + + for i in range(det_bboxes.shape[0]): + score = det_scores[i] + box = det_bboxes[i,0:4] + # Decoded detection boxes could have negative values for width/height due + # to model prediction. 
Filter out those boxes + if box[2] < 0 or box[3] < 0: continue + kps = [] + # 0 : wrist + # 1 : index finger joint + # 2 : middle finger joint + # 3 : ring finger joint + # 4 : little finger joint + # 5 : + # 6 : thumb joint + # for j, name in enumerate(["0", "1", "2", "3", "4", "5", "6"]): + # kps[name] = det_bboxes[i,4+j*2:6+j*2] + for kp in range(7): + kps.append(det_bboxes[i,4+kp*2:6+kp*2]) + regions.append(HandRegion(float(score), box, kps)) + return regions + +def rect_transformation(regions, w, h): + """ + w, h : image input shape + """ + # https://github.com/google/mediapipe/blob/master/mediapipe/modules/hand_landmark/palm_detection_detection_to_roi.pbtxt + # # Expands and shifts the rectangle that contains the palm so that it's likely + # # to cover the entire hand. + # node { + # calculator: "RectTransformationCalculator" + # input_stream: "NORM_RECT:raw_roi" + # input_stream: "IMAGE_SIZE:image_size" + # output_stream: "roi" + # options: { + # [mediapipe.RectTransformationCalculatorOptions.ext] { + # scale_x: 2.6 + # scale_y: 2.6 + # shift_y: -0.5 + # square_long: true + # } + # } + # IMHO 2.9 is better than 2.6. With 2.6, it may happen that finger tips stay outside of the bouding rotated rectangle + scale_x = 2.9 + scale_y = 2.9 + shift_x = 0 + shift_y = -0.5 + for region in regions: + width = region.rect_w + height = region.rect_h + rotation = 0 + if rotation == 0: + region.rect_x_center_a = (region.rect_x_center + width * shift_x) * w + region.rect_y_center_a = (region.rect_y_center + height * shift_y) * h + else: + x_shift = (w * width * shift_x * math.cos(rotation) - h * height * shift_y * math.sin(rotation)) #/ w + y_shift = (w * width * shift_x * math.sin(rotation) + h * height * shift_y * math.cos(rotation)) #/ h + region.rect_x_center_a = region.rect_x_center*w + x_shift + region.rect_y_center_a = region.rect_y_center*h + y_shift + + # square_long: true + long_side = max(width * w, height * h) + region.rect_w_a = long_side * scale_x + region.rect_h_a = long_side * scale_y + region.rect_points = rotated_rect_to_points(region.rect_x_center_a, region.rect_y_center_a, region.rect_w_a, region.rect_h_a, region.rotation) + +def rotated_rect_to_points(cx, cy, w, h, rotation): + b = math.cos(rotation) * 0.5 + a = math.sin(rotation) * 0.5 + points = [] + p0x = cx - a*h - b*w + p0y = cy + b*h - a*w + p1x = cx + a*h - b*w + p1y = cy - b*h - a*w + p2x = int(2*cx - p0x) + p2y = int(2*cy - p0y) + p3x = int(2*cx - p1x) + p3y = int(2*cy - p1y) + p0x, p0y, p1x, p1y = int(p0x), int(p0y), int(p1x), int(p1y) + return [[p0x,p0y], [p1x,p1y], [p2x,p2y], [p3x,p3y]] + +def detections_to_rect(regions): + # https://github.com/google/mediapipe/blob/master/mediapipe/modules/hand_landmark/palm_detection_detection_to_roi.pbtxt + # # Converts results of palm detection into a rectangle (normalized by image size) + # # that encloses the palm and is rotated such that the line connecting center of + # # the wrist and MCP of the middle finger is aligned with the Y-axis of the + # # rectangle. + # node { + # calculator: "DetectionsToRectsCalculator" + # input_stream: "DETECTION:detection" + # input_stream: "IMAGE_SIZE:image_size" + # output_stream: "NORM_RECT:raw_roi" + # options: { + # [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + # rotation_vector_start_keypoint_index: 0 # Center of wrist. + # rotation_vector_end_keypoint_index: 2 # MCP of middle finger. 
+ # rotation_vector_target_angle_degrees: 90 + # } + # } + + target_angle = math.pi * 0.5 # 90 = pi/2 + for region in regions: + + region.rect_w = region.pd_box[2] + region.rect_h = region.pd_box[3] + region.rect_x_center = region.pd_box[0] + region.rect_w / 2 + region.rect_y_center = region.pd_box[1] + region.rect_h / 2 + + x0, y0 = region.pd_kps[0] # wrist center + x1, y1 = region.pd_kps[2] # middle finger + rotation = target_angle - math.atan2(-(y1 - y0), x1 - x0) + region.rotation = normalize_radians(rotation) + +def normalize_radians(angle): + return angle - 2 * math.pi * math.floor((angle + math.pi) / (2 * math.pi)) + +def non_maxima_suppression(bboxes, iou_threshold): + if len(bboxes) == 0: + return [] + + if bboxes.dtype.kind == 'i': + bboxes = bboxes.astype('float') + + pick = [] + + x1 = bboxes[:,0] + y1 = bboxes[:,1] + x2 = bboxes[:,2] + y2 = bboxes[:,3] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + idxs = np.argsort(y2) + + while len(idxs) > 0: + last = len(idxs) - 1 + i = idxs[last] + pick.append(i) + + xx1 = np.maximum(x1[i], x1[idxs[:last]]) + yy1 = np.maximum(y1[i], y1[idxs[:last]]) + xx2 = np.minimum(x2[i], x2[idxs[:last]]) + yy2 = np.minimum(y2[i], y2[idxs[:last]]) + + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + + overlap = (w * h) / area[idxs[:last]] + + idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > iou_threshold)[0]))) + + return bboxes[pick].astype('int') \ No newline at end of file From eac159cc0eb92d58ea88696aaab5303150e5b5ef Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Tue, 28 May 2024 09:52:00 +0200 Subject: [PATCH 04/29] More general function for depth or segmentation message creation. --- ml/postprocessing/selfie_seg.py | 4 ++-- .../message_creation/depth_segmentation.py | 23 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ml/postprocessing/selfie_seg.py b/ml/postprocessing/selfie_seg.py index 24eb620e..76a9a709 100644 --- a/ml/postprocessing/selfie_seg.py +++ b/ml/postprocessing/selfie_seg.py @@ -1,7 +1,7 @@ import depthai as dai import numpy as np import cv2 -from .utils.message_creation.depth_segmentation import create_segmentation_msg +from .utils.message_creation.depth_segmentation import create_depth_segmentation_msg class SeflieSegParser(dai.node.ThreadedHostNode): def __init__( @@ -45,5 +45,5 @@ def run(self): overlay_image = np.zeros((segmentation_mask.shape[0], segmentation_mask.shape[1], 1), dtype=np.uint8) overlay_image[segmentation_mask] = 1 - imgFrame = create_segmentation_msg(overlay_image) + imgFrame = create_depth_segmentation_msg(overlay_image, 'raw8') self.out.send(imgFrame) \ No newline at end of file diff --git a/ml/postprocessing/utils/message_creation/depth_segmentation.py b/ml/postprocessing/utils/message_creation/depth_segmentation.py index 8c6c041c..519a195f 100644 --- a/ml/postprocessing/utils/message_creation/depth_segmentation.py +++ b/ml/postprocessing/utils/message_creation/depth_segmentation.py @@ -1,20 +1,27 @@ import depthai as dai import numpy as np -def create_segmentation_msg(mask: np.array) -> dai.ImgFrame: +def create_depth_segmentation_msg(x: np.array, img_frame_type: str) -> dai.ImgFrame: """ - Create a message for the segmentation mask. Mask is of the shape (H, W, 1). In the third dimesion we specify the class. + Create a message for the segmentation mask or depth image. Input is of the shape (H, W, 1). + In the third dimesion we specify the class for segmentation task or depth for depth task. 
Args: - mask (np.array): The segmentation mask. + x (np.array): Input from the segmentation or depth node. + img_frame_type (str): Type of the image frame. Only 'raw8' and 'raw16' are supported. RAW16 is used for depth task and RAW8 for segmentation task. Returns: - dai.ImgFrame: The message containing the segmentation mask. + dai.ImgFrame: Output with segmentation classes or depth values. """ + if img_frame_type.lower() not in ["raw8", "raw16"]: + raise ValueError(f"Invalid image frame type: {img_frame_type}. Only 'raw16' and 'raw8' are supported.") imgFrame = dai.ImgFrame() - imgFrame.setFrame(mask) - imgFrame.setWidth(mask.shape[1]) - imgFrame.setHeight(mask.shape[0]) - imgFrame.setType(dai.ImgFrame.Type.GRAY8) + imgFrame.setFrame(x) + imgFrame.setWidth(x.shape[1]) + imgFrame.setHeight(x.shape[0]) + if img_frame_type.lower() == "raw8": + imgFrame.setType(dai.ImgFrame.Type.RAW8) + elif img_frame_type.lower() == "raw16": + imgFrame.setType(dai.ImgFrame.Type.RAW16) return imgFrame \ No newline at end of file From 65c8b7558538d9adf4b2b73fff499f022fdf0edc Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Tue, 28 May 2024 09:58:30 +0200 Subject: [PATCH 05/29] Typo. --- ml/messages/landmarks.py | 2 +- ml/postprocessing/mp_hand_landmark.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ml/messages/landmarks.py b/ml/messages/landmarks.py index 72dec25d..0dc8874c 100644 --- a/ml/messages/landmarks.py +++ b/ml/messages/landmarks.py @@ -5,5 +5,5 @@ class HandLandmarks(dai.Buffer): def __init__(self): dai.Buffer.__init__(self) self.confidence: float = 0.0 - self.handness: float = 0.0 + self.handdedness: float = 0.0 self.landmarks: List[dai.Point3f] = [] \ No newline at end of file diff --git a/ml/postprocessing/mp_hand_landmark.py b/ml/postprocessing/mp_hand_landmark.py index 476c71da..0a91f78b 100644 --- a/ml/postprocessing/mp_hand_landmark.py +++ b/ml/postprocessing/mp_hand_landmark.py @@ -62,11 +62,11 @@ def run(self): if hand_score < self.score_threshold: hand_landmarks_msg.landmarks = [] hand_landmarks_msg.confidence = hand_score - hand_landmarks_msg.handedness = handdedness + hand_landmarks_msg.handdedness = handdedness self.out.send(hand_landmarks_msg) else: hand_landmarks_msg.confidence = hand_score - hand_landmarks_msg.handedness = handdedness + hand_landmarks_msg.handdedness = handdedness for i in range(21): pt = dai.Point3f() pt.x = landmarks[i][0] From 869b607840f0cf7e875e7a7be5d06feb69c5d05c Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Wed, 29 May 2024 17:02:15 +0200 Subject: [PATCH 06/29] Adding HandLandmarksDescriptor. 
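The descriptor enforces that `landmarks` only ever holds a list of
dai.Point3f objects and raises a TypeError otherwise. A minimal usage
sketch of the intended behaviour (hypothetical values, assuming a working
depthai installation and the package importable as `ml`):

    import depthai as dai
    from ml.messages import HandLandmarks

    msg = HandLandmarks()
    msg.landmarks = [dai.Point3f()]    # accepted
    msg.landmarks = [0.1, 0.2, 0.3]    # rejected with TypeError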
--- ml/messages/__init__.py | 2 +- ml/messages/hand_landmarks.py | 28 +++++++++++++++++++++++ ml/messages/landmarks.py | 9 -------- ml/postprocessing/mp_hand_landmark.py | 33 ++++++++------------------- 4 files changed, 39 insertions(+), 33 deletions(-) create mode 100644 ml/messages/hand_landmarks.py delete mode 100644 ml/messages/landmarks.py diff --git a/ml/messages/__init__.py b/ml/messages/__init__.py index 6546ea66..c2dd2a7f 100644 --- a/ml/messages/__init__.py +++ b/ml/messages/__init__.py @@ -1,2 +1,2 @@ from .img_detections import ImgDetectionsWithKeypoints -from .landmarks import HandLandmarks \ No newline at end of file +from .hand_landmarks import HandLandmarks \ No newline at end of file diff --git a/ml/messages/hand_landmarks.py b/ml/messages/hand_landmarks.py new file mode 100644 index 00000000..366e9be8 --- /dev/null +++ b/ml/messages/hand_landmarks.py @@ -0,0 +1,28 @@ +import depthai as dai +from typing import List + +class HandLandmarksDescriptor: + def __init__(self): + self.name = 'landmarks' + self.expected_type = dai.Point3f + + def __get__(self, instance, owner): + if instance is None: + return self + return instance.__dict__[self.name] + + def __set__(self, instance, value): + if not isinstance(value, list): + raise TypeError(f"{self.name} must be a list") + for item in value: + if not isinstance(item, self.expected_type): + raise TypeError(f"All items in {self.name} must be of type {self.expected_type}") + instance.__dict__[self.name] = value + +class HandLandmarks(dai.Buffer): + landmarks = HandLandmarksDescriptor() + def __init__(self): + dai.Buffer.__init__(self) + self.confidence: float = 0.0 + self.handdedness: float = 0.0 + self.landmarks: List[dai.Point3f] = [] \ No newline at end of file diff --git a/ml/messages/landmarks.py b/ml/messages/landmarks.py deleted file mode 100644 index 0dc8874c..00000000 --- a/ml/messages/landmarks.py +++ /dev/null @@ -1,9 +0,0 @@ -import depthai as dai -from typing import List - -class HandLandmarks(dai.Buffer): - def __init__(self): - dai.Buffer.__init__(self) - self.confidence: float = 0.0 - self.handdedness: float = 0.0 - self.landmarks: List[dai.Point3f] = [] \ No newline at end of file diff --git a/ml/postprocessing/mp_hand_landmark.py b/ml/postprocessing/mp_hand_landmark.py index 0a91f78b..197ae1f6 100644 --- a/ml/postprocessing/mp_hand_landmark.py +++ b/ml/postprocessing/mp_hand_landmark.py @@ -7,33 +7,23 @@ class MPHandLandmarkParser(dai.node.ThreadedHostNode): def __init__( self, - score_threshold=0.5, - handdedness_threshold=0.5, - input_size=(224, 224) + score_threshold=0.5 ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) self.score_threshold = score_threshold - self.input_size = input_size - self.handdedness_threshold = handdedness_threshold def setScoreThreshold(self, threshold): self.score_threshold = threshold - def setHandednessThreshold(self, threshold): - self.handdedness_threshold = threshold - - def setInputSize(self, width, height): - self.input_size = (width, height) - def run(self): """ - Postprocessing logic for SCRFD model. + Postprocessing logic for MediaPipe Hand landmark model. Returns: - ... + HandLandmarks containing 21 landmarks, confidence score, and handdedness score (right or left hand). 
""" while self.isRunning(): @@ -59,18 +49,15 @@ def run(self): handdedness = handdedness[0] hand_landmarks_msg = HandLandmarks() - if hand_score < self.score_threshold: - hand_landmarks_msg.landmarks = [] - hand_landmarks_msg.confidence = hand_score - hand_landmarks_msg.handdedness = handdedness - self.out.send(hand_landmarks_msg) - else: - hand_landmarks_msg.confidence = hand_score - hand_landmarks_msg.handdedness = handdedness + hand_landmarks_msg.handdedness = handdedness + hand_landmarks_msg.confidence = hand_score + hand_landmarks = [] + if hand_score >= self.score_threshold: for i in range(21): pt = dai.Point3f() pt.x = landmarks[i][0] pt.y = landmarks[i][1] pt.z = landmarks[i][2] - hand_landmarks_msg.landmarks.append(pt) - self.out.send(hand_landmarks_msg) \ No newline at end of file + hand_landmarks.append(pt) + hand_landmarks_msg.landmarks = hand_landmarks + self.out.send(hand_landmarks_msg) \ No newline at end of file From 3434ca395ef0f20335d130c21aefb54fd1cec0cc Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Wed, 29 May 2024 17:02:50 +0200 Subject: [PATCH 07/29] Removing unnecessary variables. --- ml/postprocessing/mp_palm_detection.py | 11 +++-------- ml/postprocessing/scrfd.py | 7 +------ 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/ml/postprocessing/mp_palm_detection.py b/ml/postprocessing/mp_palm_detection.py index f09c6c09..97b8e7a8 100644 --- a/ml/postprocessing/mp_palm_detection.py +++ b/ml/postprocessing/mp_palm_detection.py @@ -10,14 +10,12 @@ def __init__( self, score_threshold=0.5, nms_threshold=0.5, - top_k=100, - input_size=(192, 192), # WH + top_k=100 ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) - self.input_size = input_size self.score_threshold = score_threshold self.nms_threshold = nms_threshold self.top_k = top_k @@ -31,15 +29,12 @@ def setNMSThreshold(self, threshold): def setTopK(self, top_k): self.top_k = top_k - def setInputSize(self, width, height): - self.input_size = (width, height) - def run(self): """ - Postprocessing logic for SCRFD model. + Postprocessing logic for MediPipe Hand detection model. Returns: - ... + dai.ImgDetections containing bounding boxes of detected hands, label, and confidence score. """ while self.isRunning(): diff --git a/ml/postprocessing/scrfd.py b/ml/postprocessing/scrfd.py index 379ca063..22f609e6 100644 --- a/ml/postprocessing/scrfd.py +++ b/ml/postprocessing/scrfd.py @@ -9,14 +9,12 @@ def __init__( self, score_threshold=0.5, nms_threshold=0.5, - top_k=100, - input_size=(640, 640), # WH + top_k=100 ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) - self.input_size = input_size self.score_threshold = score_threshold self.nms_threshold = nms_threshold self.top_k = top_k @@ -30,9 +28,6 @@ def setNMSThreshold(self, threshold): def setTopK(self, top_k): self.top_k = top_k - def setInputSize(self, width, height): - self.input_size = (width, height) - def run(self): """ Postprocessing logic for SCRFD model. From d796a21a2e51bf7f3caf194c8935c94407037239 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Wed, 29 May 2024 17:06:58 +0200 Subject: [PATCH 08/29] Renaming mp hands nodes. 
--- .../{mp_palm_detection.py => mediapipe_hand_detection.py} | 2 +- .../{mp_hand_landmark.py => mediapipe_hand_landmarker.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename ml/postprocessing/{mp_palm_detection.py => mediapipe_hand_detection.py} (98%) rename ml/postprocessing/{mp_hand_landmark.py => mediapipe_hand_landmarker.py} (100%) diff --git a/ml/postprocessing/mp_palm_detection.py b/ml/postprocessing/mediapipe_hand_detection.py similarity index 98% rename from ml/postprocessing/mp_palm_detection.py rename to ml/postprocessing/mediapipe_hand_detection.py index 97b8e7a8..139416a1 100644 --- a/ml/postprocessing/mp_palm_detection.py +++ b/ml/postprocessing/mediapipe_hand_detection.py @@ -5,7 +5,7 @@ from ..messages import ImgDetectionsWithKeypoints from .utils.medipipe_utils import generate_handtracker_anchors, decode_bboxes, rect_transformation, detections_to_rect -class MPPalmDetectionParser(dai.node.ThreadedHostNode): +class MPHandDetectionParser(dai.node.ThreadedHostNode): def __init__( self, score_threshold=0.5, diff --git a/ml/postprocessing/mp_hand_landmark.py b/ml/postprocessing/mediapipe_hand_landmarker.py similarity index 100% rename from ml/postprocessing/mp_hand_landmark.py rename to ml/postprocessing/mediapipe_hand_landmarker.py From 26a53e4c98c87178bd5f3e864f4f60215283d975 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Wed, 29 May 2024 17:10:14 +0200 Subject: [PATCH 09/29] Rename MP selfie segmentation node. --- ...selfie_seg.py => mediapipe_selfie_segmentation.py} | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) rename ml/postprocessing/{selfie_seg.py => mediapipe_selfie_segmentation.py} (82%) diff --git a/ml/postprocessing/selfie_seg.py b/ml/postprocessing/mediapipe_selfie_segmentation.py similarity index 82% rename from ml/postprocessing/selfie_seg.py rename to ml/postprocessing/mediapipe_selfie_segmentation.py index 76a9a709..5cd882d0 100644 --- a/ml/postprocessing/selfie_seg.py +++ b/ml/postprocessing/mediapipe_selfie_segmentation.py @@ -3,31 +3,26 @@ import cv2 from .utils.message_creation.depth_segmentation import create_depth_segmentation_msg -class SeflieSegParser(dai.node.ThreadedHostNode): +class MPSeflieSegParser(dai.node.ThreadedHostNode): def __init__( self, threshold=0.5, - input_size=(256, 144), ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) - self.input_size = input_size self.threshold = threshold def setConfidenceThreshold(self, threshold): self.threshold = threshold - def setInputSize(self, width, height): - self.input_size = (width, height) - def run(self): """ - Postprocessing logic for SCRFD model. + Postprocessing logic for MediaPipe Selfie Segmentation model. Returns: - ... + Segmenation mask with two classes 1 - person, 0 - background. """ while self.isRunning(): From 5437ca8128046d07a2e87f60828dd222c37af9ba Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 30 May 2024 15:35:12 +0200 Subject: [PATCH 10/29] HandKeypoints message using property. 
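HandLandmarks and its separate descriptor class are replaced by a single
HandKeypoints message that guards the keypoint list with a plain @property.
A short sketch of the intended behaviour (hypothetical values):

    msg = HandKeypoints()
    msg.keypoints = [dai.Point3f()]     # validated and stored
    msg.keypoints = ["not a point"]     # raises TypeError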
---
 ml/messages/__init__.py                        |  2 +-
 ml/messages/hand_keypoints.py                  | 22 +++++++++++++++++
 ml/messages/hand_landmarks.py                  | 28 -------------------
 .../mediapipe_hand_landmarker.py               |  4 +--
 4 files changed, 25 insertions(+), 31 deletions(-)
 create mode 100644 ml/messages/hand_keypoints.py
 delete mode 100644 ml/messages/hand_landmarks.py

diff --git a/ml/messages/__init__.py b/ml/messages/__init__.py
index c2dd2a7f..c239ad27 100644
--- a/ml/messages/__init__.py
+++ b/ml/messages/__init__.py
@@ -1,2 +1,2 @@
 from .img_detections import ImgDetectionsWithKeypoints
-from .hand_landmarks import HandLandmarks
\ No newline at end of file
+from .hand_keypoints import HandKeypoints
\ No newline at end of file
diff --git a/ml/messages/hand_keypoints.py b/ml/messages/hand_keypoints.py
new file mode 100644
index 00000000..71114951
--- /dev/null
+++ b/ml/messages/hand_keypoints.py
@@ -0,0 +1,22 @@
+import depthai as dai
+from typing import List
+
+class HandKeypoints(dai.Buffer):
+    def __init__(self):
+        dai.Buffer.__init__(self)
+        self.confidence: float = 0.0
+        self.handdedness: float = 0.0
+        self._keypoints: List[dai.Point3f] = []
+
+    @property
+    def keypoints(self) -> List[dai.Point3f]:
+        return self._keypoints
+
+    @keypoints.setter
+    def keypoints(self, value: List[dai.Point3f]):
+        if not isinstance(value, list):
+            raise TypeError("keypoints must be a list.")
+        for item in value:
+            if not isinstance(item, dai.Point3f):
+                raise TypeError("All items in keypoints must be of type dai.Point3f.")
+        self._keypoints = value  # store the validated list
diff --git a/ml/messages/hand_landmarks.py b/ml/messages/hand_landmarks.py
deleted file mode 100644
index 366e9be8..00000000
--- a/ml/messages/hand_landmarks.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import depthai as dai
-from typing import List
-
-class HandLandmarksDescriptor:
-    def __init__(self):
-        self.name = 'landmarks'
-        self.expected_type = dai.Point3f
-
-    def __get__(self, instance, owner):
-        if instance is None:
-            return self
-        return instance.__dict__[self.name]
-
-    def __set__(self, instance, value):
-        if not isinstance(value, list):
-            raise TypeError(f"{self.name} must be a list")
-        for item in value:
-            if not isinstance(item, self.expected_type):
-                raise TypeError(f"All items in {self.name} must be of type {self.expected_type}")
-        instance.__dict__[self.name] = value
-
-class HandLandmarks(dai.Buffer):
-    landmarks = HandLandmarksDescriptor()
-    def __init__(self):
-        dai.Buffer.__init__(self)
-        self.confidence: float = 0.0
-        self.handdedness: float = 0.0
-        self.landmarks: List[dai.Point3f] = []
\ No newline at end of file
diff --git a/ml/postprocessing/mediapipe_hand_landmarker.py b/ml/postprocessing/mediapipe_hand_landmarker.py
index 197ae1f6..a2c885c4 100644
--- a/ml/postprocessing/mediapipe_hand_landmarker.py
+++ b/ml/postprocessing/mediapipe_hand_landmarker.py
@@ -2,7 +2,7 @@
 import numpy as np
 import cv2
 
-from ..messages import HandLandmarks
+from ..messages import HandKeypoints
@@ -48,7 +48,7 @@ def run(self):
             handdedness = (handdedness - tensorInfo.qpZp) * tensorInfo.qpScale
             handdedness = handdedness[0]
 
-            hand_landmarks_msg = HandLandmarks()
+            hand_landmarks_msg = HandKeypoints()
             hand_landmarks_msg.handdedness = handdedness
             hand_landmarks_msg.confidence = hand_score
             hand_landmarks = []

From 9a486bcf4ed0466527b1986a824927eb82f95210 Mon Sep 17 00:00:00 2001
From: kkeroo <61207502+kkeroo@users.noreply.github.com>
Date: Thu, 30 May 2024 15:52:12 +0200
Subject: [PATCH 11/29] Depth and Segmentation msg creation.
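The combined create_depth_segmentation_msg helper is split into two
single-purpose factories, create_depth_message (RAW16) and
create_segmentation_message (RAW8), both re-exported from
message_creation/__init__.py. A usage sketch (hypothetical arrays):

    import numpy as np

    depth = np.zeros((480, 640, 1), dtype=np.uint16)
    depth_frame = create_depth_message(depth)        # ImgFrame.Type.RAW16

    mask = np.zeros((144, 256, 1), dtype=np.uint8)
    mask_frame = create_segmentation_message(mask)   # ImgFrame.Type.RAW8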
---
 .../utils/message_creation/__init__.py        |  7 +++++
 .../utils/message_creation/depth.py           | 20 ++++++++++++++
 .../message_creation/depth_segmentation.py    | 27 -------------------
 .../utils/message_creation/segmentation.py    | 20 ++++++++++++++
 4 files changed, 47 insertions(+), 27 deletions(-)
 create mode 100644 ml/postprocessing/utils/message_creation/__init__.py
 create mode 100644 ml/postprocessing/utils/message_creation/depth.py
 delete mode 100644 ml/postprocessing/utils/message_creation/depth_segmentation.py
 create mode 100644 ml/postprocessing/utils/message_creation/segmentation.py

diff --git a/ml/postprocessing/utils/message_creation/__init__.py b/ml/postprocessing/utils/message_creation/__init__.py
new file mode 100644
index 00000000..228d1f34
--- /dev/null
+++ b/ml/postprocessing/utils/message_creation/__init__.py
@@ -0,0 +1,7 @@
+from .depth import create_depth_message
+from .segmentation import create_segmentation_message
+
+__all__ = [
+    "create_depth_message",
+    "create_segmentation_message",
+]
\ No newline at end of file
diff --git a/ml/postprocessing/utils/message_creation/depth.py b/ml/postprocessing/utils/message_creation/depth.py
new file mode 100644
index 00000000..5c28fde3
--- /dev/null
+++ b/ml/postprocessing/utils/message_creation/depth.py
@@ -0,0 +1,20 @@
+import depthai as dai
+import numpy as np
+
+def create_depth_message(x: np.array) -> dai.ImgFrame:
+    """
+    Create a message for the depth image. Input is of the shape (H, W, 1).
+    In the third dimension we specify the depth in the image.
+
+    Args:
+        x (np.array): Input from the depth node.
+
+    Returns:
+        dai.ImgFrame: Output depth message in ImgFrame.Type.RAW16.
+    """
+    imgFrame = dai.ImgFrame()
+    imgFrame.setFrame(x)
+    imgFrame.setWidth(x.shape[1])
+    imgFrame.setHeight(x.shape[0])
+    imgFrame.setType(dai.ImgFrame.Type.RAW16)
+    return imgFrame
\ No newline at end of file
diff --git a/ml/postprocessing/utils/message_creation/depth_segmentation.py b/ml/postprocessing/utils/message_creation/depth_segmentation.py
deleted file mode 100644
index 519a195f..00000000
--- a/ml/postprocessing/utils/message_creation/depth_segmentation.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import depthai as dai
-import numpy as np
-
-def create_depth_segmentation_msg(x: np.array, img_frame_type: str) -> dai.ImgFrame:
-    """
-    Create a message for the segmentation mask or depth image. Input is of the shape (H, W, 1).
-    In the third dimesion we specify the class for segmentation task or depth for depth task.
-
-    Args:
-        x (np.array): Input from the segmentation or depth node.
-        img_frame_type (str): Type of the image frame. Only 'raw8' and 'raw16' are supported. RAW16 is used for depth task and RAW8 for segmentation task.
-
-    Returns:
-        dai.ImgFrame: Output with segmentation classes or depth values.
-    """
-    if img_frame_type.lower() not in ["raw8", "raw16"]:
-        raise ValueError(f"Invalid image frame type: {img_frame_type}. Only 'raw16' and 'raw8' are supported.")
-    imgFrame = dai.ImgFrame()
-    imgFrame.setFrame(x)
-    imgFrame.setWidth(x.shape[1])
-    imgFrame.setHeight(x.shape[0])
-    if img_frame_type.lower() == "raw8":
-        imgFrame.setType(dai.ImgFrame.Type.RAW8)
-    elif img_frame_type.lower() == "raw16":
-        imgFrame.setType(dai.ImgFrame.Type.RAW16)
-
-    return imgFrame
\ No newline at end of file
diff --git a/ml/postprocessing/utils/message_creation/segmentation.py b/ml/postprocessing/utils/message_creation/segmentation.py
new file mode 100644
index 00000000..eb2ec431
--- /dev/null
+++ b/ml/postprocessing/utils/message_creation/segmentation.py
@@ -0,0 +1,20 @@
+import depthai as dai
+import numpy as np
+
+def create_segmentation_message(x: np.array) -> dai.ImgFrame:
+    """
+    Create a message for the segmentation node output. Input is of the shape (H, W, 1).
+    In the third dimension we specify the class of the segmented objects.
+
+    Args:
+        x (np.array): Input from the segmentation node.
+
+    Returns:
+        dai.ImgFrame: Output segmentation message in ImgFrame.Type.RAW8.
+    """
+    imgFrame = dai.ImgFrame()
+    imgFrame.setFrame(x)
+    imgFrame.setWidth(x.shape[1])
+    imgFrame.setHeight(x.shape[0])
+    imgFrame.setType(dai.ImgFrame.Type.RAW8)
+    return imgFrame
\ No newline at end of file
From c084392eb75e7c2de5902ac80ab7863ff0b0c7a9 Mon Sep 17 00:00:00 2001
From: kkeroo <61207502+kkeroo@users.noreply.github.com>
Date: Thu, 30 May 2024 15:54:57 +0200
Subject: [PATCH 12/29] Segmentation msg added to host node.

---
 ml/postprocessing/mediapipe_selfie_segmentation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ml/postprocessing/mediapipe_selfie_segmentation.py b/ml/postprocessing/mediapipe_selfie_segmentation.py
index 5cd882d0..58b53b78 100644
--- a/ml/postprocessing/mediapipe_selfie_segmentation.py
+++ b/ml/postprocessing/mediapipe_selfie_segmentation.py
@@ -1,7 +1,7 @@
 import depthai as dai
 import numpy as np
 import cv2
-from .utils.message_creation.depth_segmentation import create_depth_segmentation_msg
+from .utils.message_creation import create_segmentation_message
 
 class MPSeflieSegParser(dai.node.ThreadedHostNode):
     def __init__(
@@ -40,5 +40,5 @@ def run(self):
         overlay_image = np.zeros((segmentation_mask.shape[0], segmentation_mask.shape[1], 1), dtype=np.uint8)
         overlay_image[segmentation_mask] = 1
 
-        imgFrame = create_depth_segmentation_msg(overlay_image, 'raw8')
+        imgFrame = create_segmentation_message(overlay_image)
         self.out.send(imgFrame)
\ No newline at end of file
From b1a5ce06affbedf14540903b0550e2053dafbf70 Mon Sep 17 00:00:00 2001
From: kkeroo <61207502+kkeroo@users.noreply.github.com>
Date: Thu, 30 May 2024 15:58:14 +0200
Subject: [PATCH 13/29] Removed unnecessary prints.
--- ml/postprocessing/mediapipe_hand_detection.py | 8 ++------ ml/postprocessing/mediapipe_hand_landmarker.py | 3 --- ml/postprocessing/mediapipe_selfie_segmentation.py | 3 --- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/ml/postprocessing/mediapipe_hand_detection.py b/ml/postprocessing/mediapipe_hand_detection.py index 139416a1..f38c186b 100644 --- a/ml/postprocessing/mediapipe_hand_detection.py +++ b/ml/postprocessing/mediapipe_hand_detection.py @@ -2,8 +2,7 @@ import numpy as np import cv2 -from ..messages import ImgDetectionsWithKeypoints -from .utils.medipipe_utils import generate_handtracker_anchors, decode_bboxes, rect_transformation, detections_to_rect +from .utils.medipipe import generate_handtracker_anchors, decode_bboxes, rect_transformation, detections_to_rect class MPHandDetectionParser(dai.node.ThreadedHostNode): def __init__( @@ -34,7 +33,7 @@ def run(self): Postprocessing logic for MediPipe Hand detection model. Returns: - dai.ImgDetections containing bounding boxes of detected hands, label, and confidence score. + dai.ImgDetections containing bounding boxes, labels, and confidence scores of detected hands. """ while self.isRunning(): @@ -44,9 +43,6 @@ def run(self): except dai.MessageQueue.QueueException as e: break # Pipeline was stopped - print('MP Palm detection node') - print(f"Layer names = {output.getAllLayerNames()}") - tensorInfo = output.getTensorInfo("Identity") bboxes = output.getTensor(f"Identity").reshape(2016, 18).astype(np.float32) bboxes = (bboxes - tensorInfo.qpZp) * tensorInfo.qpScale diff --git a/ml/postprocessing/mediapipe_hand_landmarker.py b/ml/postprocessing/mediapipe_hand_landmarker.py index a2c885c4..ac1265b4 100644 --- a/ml/postprocessing/mediapipe_hand_landmarker.py +++ b/ml/postprocessing/mediapipe_hand_landmarker.py @@ -33,9 +33,6 @@ def run(self): except dai.MessageQueue.QueueException as e: break # Pipeline was stopped - print('MP Hand landmark node') - print(f"Layer names = {output.getAllLayerNames()}") - tensorInfo = output.getTensorInfo("Identity") landmarks = output.getTensor(f"Identity").reshape(21, 3).astype(np.float32) landmarks = (landmarks - tensorInfo.qpZp) * tensorInfo.qpScale diff --git a/ml/postprocessing/mediapipe_selfie_segmentation.py b/ml/postprocessing/mediapipe_selfie_segmentation.py index 58b53b78..efb4306f 100644 --- a/ml/postprocessing/mediapipe_selfie_segmentation.py +++ b/ml/postprocessing/mediapipe_selfie_segmentation.py @@ -29,12 +29,9 @@ def run(self): try: output: dai.NNData = self.input.get() - print(f"output = {output}") except dai.MessageQueue.QueueException as e: break # Pipeline was stopped - print(f"Layer names = {output.getAllLayerNames()}") - segmentation_mask = output.getTensor("output") segmentation_mask = segmentation_mask[0].squeeze() > self.threshold overlay_image = np.zeros((segmentation_mask.shape[0], segmentation_mask.shape[1], 1), dtype=np.uint8) From d7171e842c649013ad806a4e2965195b8647bdf0 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 30 May 2024 15:58:27 +0200 Subject: [PATCH 14/29] Renamed mediapipe. 
--- .../utils/{medipipe_utils.py => medipipe.py} | 73 ------------------- 1 file changed, 73 deletions(-) rename ml/postprocessing/utils/{medipipe_utils.py => medipipe.py} (81%) diff --git a/ml/postprocessing/utils/medipipe_utils.py b/ml/postprocessing/utils/medipipe.py similarity index 81% rename from ml/postprocessing/utils/medipipe_utils.py rename to ml/postprocessing/utils/medipipe.py index 94fe3c9a..c3ea32c4 100644 --- a/ml/postprocessing/utils/medipipe_utils.py +++ b/ml/postprocessing/utils/medipipe.py @@ -16,35 +16,12 @@ class HandRegion: rect_points : list of the 4 points coordinates of the rotated bounding rectangle, in pixels expressed in the squared image during processing, expressed in the source rectangular image when returned to the user - lm_score: global landmark score - norm_landmarks : 3D landmarks coordinates in the rotated bounding rectangle, normalized [0,1] - landmarks : 2D landmark coordinates in pixel in the source rectangular image - world_landmarks : 3D landmark coordinates in meter - handedness: float between 0. and 1., > 0.5 for right hand, < 0.5 for left hand, - label: "left" or "right", handedness translated in a string, - xyz: real 3D world coordinates of the wrist landmark, or of the palm center (if landmarks are not used), - xyz_zone: (left, top, right, bottom), pixel coordinates in the source rectangular image - of the rectangular zone used to estimate the depth - gesture: (optional, set in recognize_gesture() when use_gesture==True) string corresponding to recognized gesture ("ONE","TWO","THREE","FOUR","FIVE","FIST","OK","PEACE") - or None if no gesture has been recognized """ def __init__(self, pd_score=None, pd_box=None, pd_kps=None): self.pd_score = pd_score # Palm detection score self.pd_box = pd_box # Palm detection box [x, y, w, h] normalized self.pd_kps = pd_kps # Palm detection keypoints - def get_rotated_world_landmarks(self): - world_landmarks_rotated = self.world_landmarks.copy() - sin_rot = math.sin(self.rotation) - cos_rot = math.cos(self.rotation) - rot_m = np.array([[cos_rot, sin_rot], [-sin_rot, cos_rot]]) - world_landmarks_rotated[:,:2] = np.dot(world_landmarks_rotated[:,:2], rot_m) - return world_landmarks_rotated - - def print(self): - attrs = vars(self) - print('\n'.join("%s: %s" % item for item in attrs.items())) - SSDAnchorOptions = namedtuple('SSDAnchorOptions',[ 'num_layers', 'min_scale', @@ -223,19 +200,8 @@ def decode_bboxes(score_thresh, scores, bboxes, anchors, scale=128, best_only=Fa det_bboxes2 = bboxes[detection_mask] det_anchors = anchors[detection_mask] - # scale = 128 # x_scale, y_scale, w_scale, h_scale - # scale = 192 # x_scale, y_scale, w_scale, h_scale - - # cx, cy, w, h = bboxes[i,:4] - # cx = cx * anchor.w / wi + anchor.x_center - # cy = cy * anchor.h / hi + anchor.y_center - # lx = lx * anchor.w / wi + anchor.x_center - # ly = ly * anchor.h / hi + anchor.y_center det_bboxes = det_bboxes2* np.tile(det_anchors[:,2:4], 9) / scale + np.tile(det_anchors[:,0:2],9) - # w = w * anchor.w / wi (in the prvious line, we add anchor.x_center and anchor.y_center to w and h, we need to substract them now) - # h = h * anchor.h / hi det_bboxes[:,2:4] = det_bboxes[:,2:4] - det_anchors[:,0:2] - # box = [cx - w*0.5, cy - h*0.5, w, h] det_bboxes[:,0:2] = det_bboxes[:,0:2] - det_bboxes[:,3:4] * 0.5 for i in range(det_bboxes.shape[0]): @@ -252,8 +218,6 @@ def decode_bboxes(score_thresh, scores, bboxes, anchors, scale=128, best_only=Fa # 4 : little finger joint # 5 : # 6 : thumb joint - # for j, name in enumerate(["0", "1", "2", "3", "4", 
"5", "6"]): - # kps[name] = det_bboxes[i,4+j*2:6+j*2] for kp in range(7): kps.append(det_bboxes[i,4+kp*2:6+kp*2]) regions.append(HandRegion(float(score), box, kps)) @@ -297,7 +261,6 @@ def rect_transformation(regions, w, h): region.rect_x_center_a = region.rect_x_center*w + x_shift region.rect_y_center_a = region.rect_y_center*h + y_shift - # square_long: true long_side = max(width * w, height * h) region.rect_w_a = long_side * scale_x region.rect_h_a = long_side * scale_y @@ -352,39 +315,3 @@ def detections_to_rect(regions): def normalize_radians(angle): return angle - 2 * math.pi * math.floor((angle + math.pi) / (2 * math.pi)) - -def non_maxima_suppression(bboxes, iou_threshold): - if len(bboxes) == 0: - return [] - - if bboxes.dtype.kind == 'i': - bboxes = bboxes.astype('float') - - pick = [] - - x1 = bboxes[:,0] - y1 = bboxes[:,1] - x2 = bboxes[:,2] - y2 = bboxes[:,3] - - area = (x2 - x1 + 1) * (y2 - y1 + 1) - idxs = np.argsort(y2) - - while len(idxs) > 0: - last = len(idxs) - 1 - i = idxs[last] - pick.append(i) - - xx1 = np.maximum(x1[i], x1[idxs[:last]]) - yy1 = np.maximum(y1[i], y1[idxs[:last]]) - xx2 = np.minimum(x2[i], x2[idxs[:last]]) - yy2 = np.minimum(y2[i], y2[idxs[:last]]) - - w = np.maximum(0, xx2 - xx1 + 1) - h = np.maximum(0, yy2 - yy1 + 1) - - overlap = (w * h) / area[idxs[:last]] - - idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > iou_threshold)[0]))) - - return bboxes[pick].astype('int') \ No newline at end of file From df7a38bec9b350786e18ed1dd2cee5f210888bba Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 30 May 2024 17:46:59 +0200 Subject: [PATCH 15/29] Normalize hand landmarks. --- ml/postprocessing/mediapipe_hand_landmarker.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ml/postprocessing/mediapipe_hand_landmarker.py b/ml/postprocessing/mediapipe_hand_landmarker.py index ac1265b4..b4eca56e 100644 --- a/ml/postprocessing/mediapipe_hand_landmarker.py +++ b/ml/postprocessing/mediapipe_hand_landmarker.py @@ -7,23 +7,28 @@ class MPHandLandmarkParser(dai.node.ThreadedHostNode): def __init__( self, - score_threshold=0.5 + score_threshold=0.5, + scale_factor=224 ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) self.score_threshold = score_threshold + self.scale_factor = scale_factor def setScoreThreshold(self, threshold): self.score_threshold = threshold + def setScaleFactor(self, scale_factor): + self.scale_factor = scale_factor + def run(self): """ Postprocessing logic for MediaPipe Hand landmark model. Returns: - HandLandmarks containing 21 landmarks, confidence score, and handdedness score (right or left hand). + HandLandmarks containing normalized 21 landmarks, confidence score, and handdedness score (right or left hand). """ while self.isRunning(): @@ -45,6 +50,9 @@ def run(self): handdedness = (handdedness - tensorInfo.qpZp) * tensorInfo.qpScale handdedness = handdedness[0] + # normalize landmarks + landmarks /= self.scale_factor + hand_landmarks_msg = HandKeypoints() hand_landmarks_msg.handdedness = handdedness hand_landmarks_msg.confidence = hand_score From 8d6f6b658112c5754e17aa98427f0af9ceca32d5 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Fri, 31 May 2024 11:22:24 +0200 Subject: [PATCH 16/29] Anchors and decoding in utils. 
--- ml/postprocessing/mediapipe_hand_detection.py | 7 ++----- ml/postprocessing/utils/detection.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 ml/postprocessing/utils/detection.py diff --git a/ml/postprocessing/mediapipe_hand_detection.py b/ml/postprocessing/mediapipe_hand_detection.py index f38c186b..6978b8df 100644 --- a/ml/postprocessing/mediapipe_hand_detection.py +++ b/ml/postprocessing/mediapipe_hand_detection.py @@ -2,7 +2,7 @@ import numpy as np import cv2 -from .utils.medipipe import generate_handtracker_anchors, decode_bboxes, rect_transformation, detections_to_rect +from .utils.detection import generate_anchors_and_decode class MPHandDetectionParser(dai.node.ThreadedHostNode): def __init__( @@ -50,10 +50,7 @@ def run(self): scores = output.getTensor(f"Identity_1").reshape(2016).astype(np.float32) scores = (scores - tensorInfo.qpZp) * tensorInfo.qpScale - anchors = generate_handtracker_anchors(192, 192) - decoded_bboxes = decode_bboxes(0.5, scores, bboxes, anchors, scale=192) - detections_to_rect(decoded_bboxes) - rect_transformation(decoded_bboxes, 192, 192) + decoded_bboxes = generate_anchors_and_decode(bboxes=bboxes, scores=scores, threshold=self.score_threshold, scale=192) bboxes = [] scores = [] diff --git a/ml/postprocessing/utils/detection.py b/ml/postprocessing/utils/detection.py new file mode 100644 index 00000000..5fd01ab0 --- /dev/null +++ b/ml/postprocessing/utils/detection.py @@ -0,0 +1,11 @@ +from .medipipe import generate_handtracker_anchors, decode_bboxes, detections_to_rect, rect_transformation + +def generate_anchors_and_decode(bboxes, scores, threshold=0.5, scale=192): + """ + Generate anchors and decode bounding boxes for mediapipe hand detection model. + """ + anchors = generate_handtracker_anchors(scale, scale) + decoded_bboxes = decode_bboxes(threshold, scores, bboxes, anchors, scale=scale) + detections_to_rect(decoded_bboxes) + rect_transformation(decoded_bboxes, scale, scale) + return decoded_bboxes \ No newline at end of file From 3e90e4b5467eb7e52ed5e31df3f9b0e29a57f0ad Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Fri, 31 May 2024 12:04:47 +0200 Subject: [PATCH 17/29] Mediapipe metadata added. --- ml/postprocessing/utils/medipipe.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/ml/postprocessing/utils/medipipe.py b/ml/postprocessing/utils/medipipe.py index c3ea32c4..7bd53cd7 100644 --- a/ml/postprocessing/utils/medipipe.py +++ b/ml/postprocessing/utils/medipipe.py @@ -1,3 +1,22 @@ +""" +mediapipe.py + +Description: This script contains utility functions for decoding the output of the MediaPipe hand tracking model. + +This script contains code that is based on or directly taken from a public GitHub repository: +https://github.com/geaxgx/depthai_hand_tracker + +Original code author(s): geaxgx + +License: MIT License + +MIT License +----------- + +Copyright (c) [2021] [geax] + +""" + import math import numpy as np from collections import namedtuple From bf463a2b28a6b65c13ef2624a091687a49a0ad32 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Fri, 31 May 2024 12:40:29 +0200 Subject: [PATCH 18/29] general multiclass segmentation parser. 
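The selfie-only parser generalizes to any model that emits a
num_classes x H x W score tensor, with class 0 reserved for background.

Illustrative wiring, assuming DepthAI v3's host-node API where a custom
ThreadedHostNode is created like any other node; `nn` stands for an
already-configured NeuralNetwork node and is not defined in this repo:

    import depthai as dai

    from ml.postprocessing.segmentation import SegmentationParser

    pipeline = dai.Pipeline()
    parser = pipeline.create(SegmentationParser)
    parser.setNumClasses(3)            # background + 2 foreground classes
    parser.setConfidenceThreshold(0.4)
    nn.out.link(parser.input)          # parser.out emits a 1-channel class-id ImgFrame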
--- ...selfie_segmentation.py => segmentation.py} | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) rename ml/postprocessing/{mediapipe_selfie_segmentation.py => segmentation.py} (53%) diff --git a/ml/postprocessing/mediapipe_selfie_segmentation.py b/ml/postprocessing/segmentation.py similarity index 53% rename from ml/postprocessing/mediapipe_selfie_segmentation.py rename to ml/postprocessing/segmentation.py index efb4306f..9b6a96d8 100644 --- a/ml/postprocessing/mediapipe_selfie_segmentation.py +++ b/ml/postprocessing/segmentation.py @@ -3,26 +3,31 @@ import cv2 from .utils.message_creation import create_segmentation_message -class MPSeflieSegParser(dai.node.ThreadedHostNode): +class SegmentationParser(dai.node.ThreadedHostNode): def __init__( self, threshold=0.5, + num_classes=2, ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) self.threshold = threshold + self.num_classes = num_classes def setConfidenceThreshold(self, threshold): self.threshold = threshold + def setNumClasses(self, num_classes): + self.num_classes = num_classes + def run(self): """ - Postprocessing logic for MediaPipe Selfie Segmentation model. + Postprocessing logic for Segmentation model with `num_classes` classes including background at index 0. Returns: - Segmenation mask with two classes 1 - person, 0 - background. + Segmenation mask with `num_classes` classes, 0 - background. """ while self.isRunning(): @@ -33,9 +38,12 @@ def run(self): break # Pipeline was stopped segmentation_mask = output.getTensor("output") - segmentation_mask = segmentation_mask[0].squeeze() > self.threshold - overlay_image = np.zeros((segmentation_mask.shape[0], segmentation_mask.shape[1], 1), dtype=np.uint8) - overlay_image[segmentation_mask] = 1 + segmentation_mask = segmentation_mask[0] # num_clases x H x W + overlay_image = np.zeros((segmentation_mask.shape[1], segmentation_mask.shape[2], 1), dtype=np.uint8) + + for class_id in range(self.num_classes-1): + class_mask = segmentation_mask[class_id] > self.threshold + overlay_image[class_mask] = class_id + 1 imgFrame = create_segmentation_message(overlay_image) self.out.send(imgFrame) \ No newline at end of file From 26381fa25d90da72c58e87dba126cd159e4a88ff Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:58:14 +0200 Subject: [PATCH 19/29] Validating input in create functions. --- ml/postprocessing/utils/message_creation/depth.py | 6 ++++++ ml/postprocessing/utils/message_creation/segmentation.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/ml/postprocessing/utils/message_creation/depth.py b/ml/postprocessing/utils/message_creation/depth.py index 5c28fde3..d52a3907 100644 --- a/ml/postprocessing/utils/message_creation/depth.py +++ b/ml/postprocessing/utils/message_creation/depth.py @@ -12,6 +12,12 @@ def create_depth_message(x: np.array) -> dai.ImgFrame: Returns: dai.ImgFrame: Output depth message in ImgFrame.Type.RAW16. 
""" + + if not isinstance(x, np.ndarray): + raise ValueError(f"Expected numpy array, got {type(x)}.") + if len(x.shape) != 3: + raise ValueError(f"Expected 3D input, got {len(x.shape)}D input.") + imgFrame = dai.ImgFrame() imgFrame.setFrame(x) imgFrame.setWidth(x.shape[1]) diff --git a/ml/postprocessing/utils/message_creation/segmentation.py b/ml/postprocessing/utils/message_creation/segmentation.py index eb2ec431..f08d8eb7 100644 --- a/ml/postprocessing/utils/message_creation/segmentation.py +++ b/ml/postprocessing/utils/message_creation/segmentation.py @@ -12,6 +12,12 @@ def create_segmentation_message(x: np.array) -> dai.ImgFrame: Returns: dai.ImgFrame: Output segmentaion message in ImgFrame.Type.RAW8. """ + + if not isinstance(x, np.ndarray): + raise ValueError(f"Expected numpy array, got {type(x)}.") + if len(x.shape) != 3: + raise ValueError(f"Expected 3D input, got {len(x.shape)}D input.") + imgFrame = dai.ImgFrame() imgFrame.setFrame(x) imgFrame.setWidth(x.shape[1]) From 2deaeee46a41982ac61c2ba304404ddcbd010783 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:34:27 +0200 Subject: [PATCH 20/29] Function for creating detection and keypoint msgs. --- ml/postprocessing/mediapipe_hand_detection.py | 16 +----- .../mediapipe_hand_landmarker.py | 15 +---- .../utils/message_creation/__init__.py | 4 ++ .../utils/message_creation/detection.py | 56 +++++++++++++++++++ .../utils/message_creation/keypoints.py | 44 +++++++++++++++ 5 files changed, 108 insertions(+), 27 deletions(-) create mode 100644 ml/postprocessing/utils/message_creation/keypoints.py diff --git a/ml/postprocessing/mediapipe_hand_detection.py b/ml/postprocessing/mediapipe_hand_detection.py index 6978b8df..66abb23b 100644 --- a/ml/postprocessing/mediapipe_hand_detection.py +++ b/ml/postprocessing/mediapipe_hand_detection.py @@ -2,6 +2,7 @@ import numpy as np import cv2 +from .utils.message_creation import create_detection_message from .utils.detection import generate_anchors_and_decode class MPHandDetectionParser(dai.node.ThreadedHostNode): @@ -69,18 +70,5 @@ def run(self): bboxes = np.array(bboxes)[indices] scores = np.array(scores)[indices] - detections = [] - for bbox, score in zip(bboxes, scores): - detection = dai.ImgDetection() - detection.confidence = score - detection.label = 0 - detection.xmin = bbox[0] - detection.ymin = bbox[1] - detection.xmax = bbox[2] - detection.ymax = bbox[3] - detections.append(detection) - - detections_msg = dai.ImgDetections() - detections_msg.detections = detections - + detections_msg = create_detection_message(bboxes, scores, labels=None) self.out.send(detections_msg) \ No newline at end of file diff --git a/ml/postprocessing/mediapipe_hand_landmarker.py b/ml/postprocessing/mediapipe_hand_landmarker.py index b4eca56e..20f3b11f 100644 --- a/ml/postprocessing/mediapipe_hand_landmarker.py +++ b/ml/postprocessing/mediapipe_hand_landmarker.py @@ -2,7 +2,7 @@ import numpy as np import cv2 -from ..messages import HandKeypoints +from .utils.message_creation import create_hand_keypoints_message class MPHandLandmarkParser(dai.node.ThreadedHostNode): def __init__( @@ -53,16 +53,5 @@ def run(self): # normalize landmarks landmarks /= self.scale_factor - hand_landmarks_msg = HandKeypoints() - hand_landmarks_msg.handdedness = handdedness - hand_landmarks_msg.confidence = hand_score - hand_landmarks = [] - if hand_score >= self.score_threshold: - for i in range(21): - pt = dai.Point3f() - pt.x = landmarks[i][0] - pt.y = landmarks[i][1] - pt.z = 
landmarks[i][2] - hand_landmarks.append(pt) - hand_landmarks_msg.landmarks = hand_landmarks + hand_landmarks_msg = create_hand_keypoints_message(landmarks, float(handdedness), float(hand_score), self.score_threshold) self.out.send(hand_landmarks_msg) \ No newline at end of file diff --git a/ml/postprocessing/utils/message_creation/__init__.py b/ml/postprocessing/utils/message_creation/__init__.py index 228d1f34..e0b7be5f 100644 --- a/ml/postprocessing/utils/message_creation/__init__.py +++ b/ml/postprocessing/utils/message_creation/__init__.py @@ -1,7 +1,11 @@ from .depth import create_depth_message from .segmentation import create_segmentation_message +from .keypoints import create_hand_keypoints_message +from .detection import create_detection_message __all__ = [ "create_depth_message", "create_segmentation_message", + "create_hand_keypoints_message", + "create_detection_message", ] \ No newline at end of file diff --git a/ml/postprocessing/utils/message_creation/detection.py b/ml/postprocessing/utils/message_creation/detection.py index e69de29b..c2791b46 100644 --- a/ml/postprocessing/utils/message_creation/detection.py +++ b/ml/postprocessing/utils/message_creation/detection.py @@ -0,0 +1,56 @@ +import depthai as dai +import numpy as np +from typing import List + +def create_detection_message(bboxes: np.ndarray, scores: np.ndarray, labels: List[int] = None) -> dai.ImgDetections: + """ + Create a message for the detection. The message contains the bounding boxes, labels, and confidence scores of detected objects. + If there are no labels or we only have one class, we can set labels to None and all detections will have label set to 0. + + Args: + bboxes (np.ndarray): Detected bounding boxes of shape (N,4) meaning [...,[x_min, y_min, x_max, y_max],...]. + scores (np.ndarray): Confidence scores of detected objects of shape (N,). + labels (List[int], optional): Labels of detected objects of shape (N,). Defaults to None. + + Returns: + dai.ImgDetections: Message containing the bounding boxes, labels, and confidence scores of detected objects. + """ + + if not isinstance(bboxes, np.ndarray): + raise ValueError(f"bboxes should be numpy array, got {type(bboxes)}.") + if len(bboxes.shape) != 2: + raise ValueError(f"bboxes should be of shape (N,4) meaning [...,[x_min, y_min, x_max, y_max],...], got {bboxes.shape}.") + if bboxes.shape[1] != 4: + raise ValueError(f"bboxes 2nd dimension should be of size 4 e.g. 
[x_min, y_min, x_max, y_max] got {bboxes.shape[1]}.") + if not isinstance(scores, np.ndarray): + raise ValueError(f"scores should be numpy array, got {type(scores)}.") + if len(scores.shape) != 1: + raise ValueError(f"scores should be of shape (N,) meaning, got {scores.shape}.") + if scores.shape[0] != bboxes.shape[0]: + raise ValueError(f"scores should have same length as bboxes, got {scores.shape[0]} and {bboxes.shape[0]}.") + if labels is not None: + if not isinstance(labels, List): + raise ValueError(f"labels should be list, got {type(labels)}.") + for label in labels: + if not isinstance(label, int): + raise ValueError(f"labels should be list of integers, got {type(label)}.") + if len(labels) != bboxes.shape[0]: + raise ValueError(f"labels should have same length as bboxes, got {len(labels)} and {bboxes.shape[0]}.") + + if labels is None: + labels = [0 for _ in range(bboxes.shape[0])] + + detections = [] + for bbox, score, label in zip(bboxes, scores, labels): + detection = dai.ImgDetection() + detection.confidence = score + detection.label = label + detection.xmin = bbox[0] + detection.ymin = bbox[1] + detection.xmax = bbox[2] + detection.ymax = bbox[3] + detections.append(detection) + + detections_msg = dai.ImgDetections() + detections_msg.detections = detections + return detections_msg \ No newline at end of file diff --git a/ml/postprocessing/utils/message_creation/keypoints.py b/ml/postprocessing/utils/message_creation/keypoints.py new file mode 100644 index 00000000..de6b544c --- /dev/null +++ b/ml/postprocessing/utils/message_creation/keypoints.py @@ -0,0 +1,44 @@ +import depthai as dai +import numpy as np +from typing import List +from ....messages import HandKeypoints + +def create_hand_keypoints_message(hand_keypoints: np.ndarray, handdedness: float, confidence: float, confidence_threshold: float) -> HandKeypoints: + """ + Create a message for the detection. The message contains the bounding boxes, labels, and confidence scores of detected objects. + If there are no labels or we only have one class, we can set labels to None and all detections will have label set to 0. + + Args: + bboxes (np.ndarray): Detected bounding boxes of shape (N,4) meaning [...,[x_min, y_min, x_max, y_max],...]. + scores (np.ndarray): Confidence scores of detected objects of shape (N,). + labels (List[int], optional): Labels of detected objects of shape (N,). Defaults to None. + + Returns: + dai.ImgDetections: Message containing the bounding boxes, labels, and confidence scores of detected objects. + """ + + if not isinstance(hand_keypoints, np.ndarray): + raise ValueError(f"hand_keypoints should be numpy array, got {type(hand_keypoints)}.") + if len(hand_keypoints.shape) != 2: + raise ValueError(f"hand_keypoints should be of shape (N,3) meaning [...,[x, y, z],...], got {hand_keypoints.shape}.") + if hand_keypoints.shape[1] != 3: + raise ValueError(f"hand_keypoints 2nd dimension should be of size 3 e.g. 
[x, y, z], got {hand_keypoints.shape[1]}.") + if not isinstance(handdedness, float): + raise ValueError(f"handdedness should be float, got {type(handdedness)}.") + if not isinstance(confidence, float): + raise ValueError(f"confidence should be float, got {type(confidence)}.") + + hand_keypoints_msg = HandKeypoints() + hand_keypoints_msg.handdedness = handdedness + hand_keypoints_msg.confidence = confidence + points = [] + if confidence >= confidence_threshold: + for i in range(hand_keypoints.shape[0]): + pt = dai.Point3f() + pt.x = hand_keypoints[i][0] + pt.y = hand_keypoints[i][1] + pt.z = hand_keypoints[i][2] + points.append(pt) + hand_keypoints_msg.landmarks = points + + return hand_keypoints_msg \ No newline at end of file From 293ce239ff55b612c832ced5ef5f1dbb791721f9 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:37:43 +0200 Subject: [PATCH 21/29] Moved anchors function. --- ml/postprocessing/mediapipe_hand_detection.py | 2 +- ml/postprocessing/utils/detection.py | 11 ----------- ml/postprocessing/utils/medipipe.py | 10 ++++++++++ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/ml/postprocessing/mediapipe_hand_detection.py b/ml/postprocessing/mediapipe_hand_detection.py index 66abb23b..6740daaa 100644 --- a/ml/postprocessing/mediapipe_hand_detection.py +++ b/ml/postprocessing/mediapipe_hand_detection.py @@ -3,7 +3,7 @@ import cv2 from .utils.message_creation import create_detection_message -from .utils.detection import generate_anchors_and_decode +from .utils.medipipe import generate_anchors_and_decode class MPHandDetectionParser(dai.node.ThreadedHostNode): def __init__( diff --git a/ml/postprocessing/utils/detection.py b/ml/postprocessing/utils/detection.py index 5fd01ab0..e69de29b 100644 --- a/ml/postprocessing/utils/detection.py +++ b/ml/postprocessing/utils/detection.py @@ -1,11 +0,0 @@ -from .medipipe import generate_handtracker_anchors, decode_bboxes, detections_to_rect, rect_transformation - -def generate_anchors_and_decode(bboxes, scores, threshold=0.5, scale=192): - """ - Generate anchors and decode bounding boxes for mediapipe hand detection model. - """ - anchors = generate_handtracker_anchors(scale, scale) - decoded_bboxes = decode_bboxes(threshold, scores, bboxes, anchors, scale=scale) - detections_to_rect(decoded_bboxes) - rect_transformation(decoded_bboxes, scale, scale) - return decoded_bboxes \ No newline at end of file diff --git a/ml/postprocessing/utils/medipipe.py b/ml/postprocessing/utils/medipipe.py index 7bd53cd7..ad762e1c 100644 --- a/ml/postprocessing/utils/medipipe.py +++ b/ml/postprocessing/utils/medipipe.py @@ -334,3 +334,13 @@ def detections_to_rect(regions): def normalize_radians(angle): return angle - 2 * math.pi * math.floor((angle + math.pi) / (2 * math.pi)) + +def generate_anchors_and_decode(bboxes, scores, threshold=0.5, scale=192): + """ + Generate anchors and decode bounding boxes for mediapipe hand detection model. + """ + anchors = generate_handtracker_anchors(scale, scale) + decoded_bboxes = decode_bboxes(threshold, scores, bboxes, anchors, scale=scale) + detections_to_rect(decoded_bboxes) + rect_transformation(decoded_bboxes, scale, scale) + return decoded_bboxes \ No newline at end of file From 45131f7eb578097f1230ecd9cd662161cb7330b8 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Mon, 3 Jun 2024 19:07:41 +0200 Subject: [PATCH 22/29] argmax to get segmentation classes. 
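Replace the per-class Python loop with one vectorized step: stack a zero
background plane on top of the foreground scores, zero out everything below
the threshold, then argmax over the class axis.

Minimal numpy sketch of the new decode (dummy shapes):

    import numpy as np

    scores = np.random.rand(2, 4, 4).astype(np.float32)  # (num_classes - 1) x H x W
    stacked = np.vstack((np.zeros((1, 4, 4), dtype=np.float32), scores))
    stacked[stacked < 0.5] = 0   # below-threshold pixels fall back to background
    class_map = np.argmax(stacked, axis=0).reshape(4, 4, 1).astype(np.uint8)

Behavioural note: on pixels where several classes pass the threshold, the
highest-scoring class now wins; the old loop let the last class written
overwrite the earlier ones.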
--- ml/postprocessing/segmentation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ml/postprocessing/segmentation.py b/ml/postprocessing/segmentation.py index 9b6a96d8..b30b1b4e 100644 --- a/ml/postprocessing/segmentation.py +++ b/ml/postprocessing/segmentation.py @@ -39,11 +39,9 @@ def run(self): segmentation_mask = output.getTensor("output") segmentation_mask = segmentation_mask[0] # num_clases x H x W - overlay_image = np.zeros((segmentation_mask.shape[1], segmentation_mask.shape[2], 1), dtype=np.uint8) - - for class_id in range(self.num_classes-1): - class_mask = segmentation_mask[class_id] > self.threshold - overlay_image[class_mask] = class_id + 1 + segmentation_mask = np.vstack((np.zeros((1, segmentation_mask.shape[1], segmentation_mask.shape[2]), dtype=np.float32), segmentation_mask)) + segmentation_mask[segmentation_mask < self.threshold] = 0 + overlay_image = np.argmax(segmentation_mask, axis=0).reshape(segmentation_mask.shape[1], segmentation_mask.shape[2], 1).astype(np.uint8) imgFrame = create_segmentation_message(overlay_image) self.out.send(imgFrame) \ No newline at end of file From 827d5c87fc9ac2201c872392690b4bfc2df86130 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Mon, 3 Jun 2024 20:00:28 +0200 Subject: [PATCH 23/29] Updated docs. --- .../utils/message_creation/keypoints.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ml/postprocessing/utils/message_creation/keypoints.py b/ml/postprocessing/utils/message_creation/keypoints.py index de6b544c..cbcf6fad 100644 --- a/ml/postprocessing/utils/message_creation/keypoints.py +++ b/ml/postprocessing/utils/message_creation/keypoints.py @@ -5,16 +5,16 @@ def create_hand_keypoints_message(hand_keypoints: np.ndarray, handdedness: float, confidence: float, confidence_threshold: float) -> HandKeypoints: """ - Create a message for the detection. The message contains the bounding boxes, labels, and confidence scores of detected objects. - If there are no labels or we only have one class, we can set labels to None and all detections will have label set to 0. + Create a message for the hand keypoint detection. The message contains the 3D coordinates of the detected hand keypoints, handdedness, and confidence score. Args: - bboxes (np.ndarray): Detected bounding boxes of shape (N,4) meaning [...,[x_min, y_min, x_max, y_max],...]. - scores (np.ndarray): Confidence scores of detected objects of shape (N,). - labels (List[int], optional): Labels of detected objects of shape (N,). Defaults to None. + hand_keypoints (np.ndarray): Detected hand keypoints of shape (N,3) meaning [...,[x, y, z],...]. + handdedness (float): Handdedness score of the detected hand (left or right). + confidence (float): Confidence score of the detected hand. + confidence_threshold (float): Confidence threshold for the overall hand. Returns: - dai.ImgDetections: Message containing the bounding boxes, labels, and confidence scores of detected objects. + HandKeypoints: Message containing the 3D coordinates of the detected hand keypoints, handdedness, and confidence score. """ if not isinstance(hand_keypoints, np.ndarray): From 25ce9cc244d470c42ec82777157e08e7c3e5e723 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Wed, 5 Jun 2024 18:21:44 +0200 Subject: [PATCH 24/29] Keypoints message. 
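Split the message into a generic Keypoints buffer plus a HandKeypoints
subclass, with type-checked properties replacing the bare attributes.

Hand-built example (values are made up; `handdedness` still carries the old
spelling at this revision):

    import depthai as dai

    from ml.messages import HandKeypoints

    msg = HandKeypoints()
    msg.confidence = 0.93      # non-float values raise TypeError
    msg.handdedness = 0.81

    wrist = dai.Point3f()
    wrist.x, wrist.y, wrist.z = 0.41, 0.52, 0.03
    msg.keypoints = [wrist]    # inherited setter checks each item is a dai.Point3f
    # msg.keypoints = [(0.41, 0.52, 0.03)]  # would raise TypeError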
--- ml/messages/hand_keypoints.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/ml/messages/hand_keypoints.py b/ml/messages/hand_keypoints.py index 71114951..8dc7f403 100644 --- a/ml/messages/hand_keypoints.py +++ b/ml/messages/hand_keypoints.py @@ -1,11 +1,10 @@ import depthai as dai from typing import List -class HandKeypoints(dai.Buffer): + +class Keypoints(dai.Buffer): def __init__(self): - dai.Buffer.__init__(self) - self.confidence: float = 0.0 - self.handdedness: float = 0.0 + super().__init__() self._keypoints: List[dai.Point3f] = [] @property @@ -19,3 +18,31 @@ def keypoints(self, value: List[dai.Point3f]): for item in value: if not isinstance(item, dai.Point3f): raise TypeError("All items in keypoints must be of type dai.Point3f.") + self._keypoints = value + + +class HandKeypoints(Keypoints): + def __init__(self): + Keypoints.__init__(self) + self._confidence: float = 0.0 + self._handdedness: float = 0.0 + + @property + def confidence(self) -> float: + return self._confidence + + @confidence.setter + def confidence(self, value: float): + if not isinstance(value, float): + raise TypeError("confidence must be a float.") + self._confidence = value + + @property + def handdedness(self) -> float: + return self._handdedness + + @handdedness.setter + def handdedness(self, value: float): + if not isinstance(value, float): + raise TypeError("handdedness must be a float.") + self._handdedness = value \ No newline at end of file From 47b8dd74cd8eb45fc8b3b685b82d7de23d88d5d7 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 6 Jun 2024 09:02:07 +0200 Subject: [PATCH 25/29] Generic keypoints. --- ml/messages/__init__.py | 2 +- ml/messages/{hand_keypoints.py => keypoints.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename ml/messages/{hand_keypoints.py => keypoints.py} (100%) diff --git a/ml/messages/__init__.py b/ml/messages/__init__.py index c239ad27..45868319 100644 --- a/ml/messages/__init__.py +++ b/ml/messages/__init__.py @@ -1,2 +1,2 @@ from .img_detections import ImgDetectionsWithKeypoints -from .hand_keypoints import HandKeypoints \ No newline at end of file +from .keypoints import HandKeypoints, Keypoints \ No newline at end of file diff --git a/ml/messages/hand_keypoints.py b/ml/messages/keypoints.py similarity index 100% rename from ml/messages/hand_keypoints.py rename to ml/messages/keypoints.py From 847165ffac1f082628fbcd9b06f4848d402252c0 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 6 Jun 2024 09:06:02 +0200 Subject: [PATCH 26/29] Correction: handdedness -> handedness. 
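Spelling fix in the landmark parser and the keypoints helper. With the
handedness convention documented in the original tracker (float in [0, 1],
above 0.5 meaning right hand), consumers can read the message as:

    label = "right" if msg.handedness > 0.5 else "left"

Caveat: ml/messages/keypoints.py keeps the `handdedness` property spelling
within this series, so the `hand_keypoints_msg.handedness = ...` assignment
below creates a plain attribute and bypasses the validated setter.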
--- ml/postprocessing/mediapipe_hand_landmarker.py | 8 ++++---- .../utils/message_creation/keypoints.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ml/postprocessing/mediapipe_hand_landmarker.py b/ml/postprocessing/mediapipe_hand_landmarker.py index 20f3b11f..9c5cca8e 100644 --- a/ml/postprocessing/mediapipe_hand_landmarker.py +++ b/ml/postprocessing/mediapipe_hand_landmarker.py @@ -46,12 +46,12 @@ def run(self): hand_score = (hand_score - tensorInfo.qpZp) * tensorInfo.qpScale hand_score = hand_score[0] tensorInfo = output.getTensorInfo("Identity_2") - handdedness = output.getTensor(f"Identity_2").reshape(-1).astype(np.float32) - handdedness = (handdedness - tensorInfo.qpZp) * tensorInfo.qpScale - handdedness = handdedness[0] + handedness = output.getTensor(f"Identity_2").reshape(-1).astype(np.float32) + handedness = (handedness - tensorInfo.qpZp) * tensorInfo.qpScale + handedness = handedness[0] # normalize landmarks landmarks /= self.scale_factor - hand_landmarks_msg = create_hand_keypoints_message(landmarks, float(handdedness), float(hand_score), self.score_threshold) + hand_landmarks_msg = create_hand_keypoints_message(landmarks, float(handedness), float(hand_score), self.score_threshold) self.out.send(hand_landmarks_msg) \ No newline at end of file diff --git a/ml/postprocessing/utils/message_creation/keypoints.py b/ml/postprocessing/utils/message_creation/keypoints.py index cbcf6fad..a8bd1482 100644 --- a/ml/postprocessing/utils/message_creation/keypoints.py +++ b/ml/postprocessing/utils/message_creation/keypoints.py @@ -3,18 +3,18 @@ from typing import List from ....messages import HandKeypoints -def create_hand_keypoints_message(hand_keypoints: np.ndarray, handdedness: float, confidence: float, confidence_threshold: float) -> HandKeypoints: +def create_hand_keypoints_message(hand_keypoints: np.ndarray, handedness: float, confidence: float, confidence_threshold: float) -> HandKeypoints: """ - Create a message for the hand keypoint detection. The message contains the 3D coordinates of the detected hand keypoints, handdedness, and confidence score. + Create a message for the hand keypoint detection. The message contains the 3D coordinates of the detected hand keypoints, handedness, and confidence score. Args: hand_keypoints (np.ndarray): Detected hand keypoints of shape (N,3) meaning [...,[x, y, z],...]. - handdedness (float): Handdedness score of the detected hand (left or right). + handedness (float): Handedness score of the detected hand (left or right). confidence (float): Confidence score of the detected hand. confidence_threshold (float): Confidence threshold for the overall hand. Returns: - HandKeypoints: Message containing the 3D coordinates of the detected hand keypoints, handdedness, and confidence score. + HandKeypoints: Message containing the 3D coordinates of the detected hand keypoints, handedness, and confidence score. """ if not isinstance(hand_keypoints, np.ndarray): @@ -23,13 +23,13 @@ def create_hand_keypoints_message(hand_keypoints: np.ndarray, handdedness: float raise ValueError(f"hand_keypoints should be of shape (N,3) meaning [...,[x, y, z],...], got {hand_keypoints.shape}.") if hand_keypoints.shape[1] != 3: raise ValueError(f"hand_keypoints 2nd dimension should be of size 3 e.g. 
[x, y, z], got {hand_keypoints.shape[1]}.") - if not isinstance(handdedness, float): - raise ValueError(f"handdedness should be float, got {type(handdedness)}.") + if not isinstance(handedness, float): + raise ValueError(f"handedness should be float, got {type(handedness)}.") if not isinstance(confidence, float): raise ValueError(f"confidence should be float, got {type(confidence)}.") hand_keypoints_msg = HandKeypoints() - hand_keypoints_msg.handdedness = handdedness + hand_keypoints_msg.handedness = handedness hand_keypoints_msg.confidence = confidence points = [] if confidence >= confidence_threshold: @@ -39,6 +39,6 @@ def create_hand_keypoints_message(hand_keypoints: np.ndarray, handdedness: float pt.y = hand_keypoints[i][1] pt.z = hand_keypoints[i][2] points.append(pt) - hand_keypoints_msg.landmarks = points + hand_keypoints_msg.keypoints = points return hand_keypoints_msg \ No newline at end of file From e2d053490fdb8064fc5a998401351dbfc3a1fad8 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 6 Jun 2024 09:18:46 +0200 Subject: [PATCH 27/29] Init file changed. --- ml/postprocessing/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ml/postprocessing/__init__.py b/ml/postprocessing/__init__.py index 5499c8fc..9c9bdba8 100644 --- a/ml/postprocessing/__init__.py +++ b/ml/postprocessing/__init__.py @@ -2,10 +2,18 @@ from .dncnn3 import DnCNN3Parser from .depth_anything import DepthAnythingParser from .yunet import YuNetParser +from .mediapipe_hand_detection import MPHandDetectionParser +from .mediapipe_hand_landmarker import MPHandLandmarkParser +from .scrfd import SCRFDParser +from .segmentation import SegmentationParser __all__ = [ 'ZeroDCEParser', 'DnCNN3Parser', 'DepthAnythingParser', - 'YuNetParser' + 'YuNetParser', + 'MPHandDetectionParser', + 'MPHandLandmarkParser', + 'SCRFDParser', + 'SegmentationParser', ] From f8a48f1d806b3833364a9fdcaec8391bbd627344 Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:16:50 +0200 Subject: [PATCH 28/29] 1 channel in 3rd dim. validation. 
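Both create functions now also insist on a single channel in the third
dimension. Sketch of adapting the usual model outputs (illustrative values;
the commented calls show the errors raised by the new checks):

    import numpy as np

    from ml.postprocessing.utils.message_creation import create_segmentation_message

    class_map = np.zeros((144, 256), dtype=np.uint8)               # H x W
    msg = create_segmentation_message(class_map[..., np.newaxis])  # (H, W, 1) passes

    # create_segmentation_message(class_map)
    #   -> ValueError: Expected 3D input, got 2D input.
    # create_segmentation_message(np.zeros((144, 256, 3), np.uint8))
    #   -> ValueError: Expected 1 channel in the third dimension, got 3 channels.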
--- ml/postprocessing/utils/message_creation/depth.py | 2 ++ ml/postprocessing/utils/message_creation/segmentation.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/ml/postprocessing/utils/message_creation/depth.py b/ml/postprocessing/utils/message_creation/depth.py index d52a3907..52ee9fcd 100644 --- a/ml/postprocessing/utils/message_creation/depth.py +++ b/ml/postprocessing/utils/message_creation/depth.py @@ -17,6 +17,8 @@ def create_depth_message(x: np.array) -> dai.ImgFrame: raise ValueError(f"Expected numpy array, got {type(x)}.") if len(x.shape) != 3: raise ValueError(f"Expected 3D input, got {len(x.shape)}D input.") + if x.shape[2] != 1: + raise ValueError(f"Expected 1 channel in the third dimension, got {x.shape[2]} channels.") imgFrame = dai.ImgFrame() imgFrame.setFrame(x) diff --git a/ml/postprocessing/utils/message_creation/segmentation.py b/ml/postprocessing/utils/message_creation/segmentation.py index f08d8eb7..25da189a 100644 --- a/ml/postprocessing/utils/message_creation/segmentation.py +++ b/ml/postprocessing/utils/message_creation/segmentation.py @@ -17,6 +17,8 @@ def create_segmentation_message(x: np.array) -> dai.ImgFrame: raise ValueError(f"Expected numpy array, got {type(x)}.") if len(x.shape) != 3: raise ValueError(f"Expected 3D input, got {len(x.shape)}D input.") + if x.shape[2] != 1: + raise ValueError(f"Expected 1 channel in the third dimension, got {x.shape[2]} channels.") imgFrame = dai.ImgFrame() imgFrame.setFrame(x) From 14a9e4c7f092d91f5cb775a539cd97576c67a21d Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:17:13 +0200 Subject: [PATCH 29/29] Remove unnecessary atributes. --- ml/postprocessing/segmentation.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/ml/postprocessing/segmentation.py b/ml/postprocessing/segmentation.py index b30b1b4e..5fd6f25a 100644 --- a/ml/postprocessing/segmentation.py +++ b/ml/postprocessing/segmentation.py @@ -6,22 +6,11 @@ class SegmentationParser(dai.node.ThreadedHostNode): def __init__( self, - threshold=0.5, - num_classes=2, ): dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) - self.threshold = threshold - self.num_classes = num_classes - - def setConfidenceThreshold(self, threshold): - self.threshold = threshold - - def setNumClasses(self, num_classes): - self.num_classes = num_classes - def run(self): """ Postprocessing logic for Segmentation model with `num_classes` classes including background at index 0. @@ -40,7 +29,6 @@ def run(self): segmentation_mask = output.getTensor("output") segmentation_mask = segmentation_mask[0] # num_clases x H x W segmentation_mask = np.vstack((np.zeros((1, segmentation_mask.shape[1], segmentation_mask.shape[2]), dtype=np.float32), segmentation_mask)) - segmentation_mask[segmentation_mask < self.threshold] = 0 overlay_image = np.argmax(segmentation_mask, axis=0).reshape(segmentation_mask.shape[1], segmentation_mask.shape[2], 1).astype(np.uint8) imgFrame = create_segmentation_message(overlay_image)