Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mp Hands nodes & segmentation msg creation #2

Merged
merged 30 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
38be7e8
New import path.
kkeroo May 28, 2024
c7a566f
Segmentation msg creation and new selfi seg. output.
kkeroo May 28, 2024
e7dc367
Nodes for MP hands.
kkeroo May 28, 2024
eac159c
More general function for depth or segmentation message creation.
kkeroo May 28, 2024
65c8b75
Typo.
kkeroo May 28, 2024
869b607
Adding HandLandmarksDescriptor.
kkeroo May 29, 2024
3434ca3
Removing unnecessary variables.
kkeroo May 29, 2024
d796a21
Renaming mp hands nodes.
kkeroo May 29, 2024
26a53e4
Rename MP selfie segmentation node.
kkeroo May 29, 2024
5437ca8
HandKeypoints message using property.
kkeroo May 30, 2024
9a486bc
Depth and Segmentation msg creation.
kkeroo May 30, 2024
c084392
Segmentation msg added to host node.
kkeroo May 30, 2024
b1a5ce0
Removed unnecessary prints.
kkeroo May 30, 2024
d7171e8
Renamed mediapipe.
kkeroo May 30, 2024
df7a38b
Normalize hand landmarks.
kkeroo May 30, 2024
8d6f6b6
Anchors and decoding in utils.
kkeroo May 31, 2024
3e90e4b
Mediapipe metadata added.
kkeroo May 31, 2024
bf463a2
general multiclass segmentation parser.
kkeroo May 31, 2024
26381fa
Validating input in create functions.
kkeroo Jun 3, 2024
2deaeee
Function for creating detection and keypoint msgs.
kkeroo Jun 3, 2024
293ce23
Moved anchors function.
kkeroo Jun 3, 2024
45131f7
argmax to get segmentation classes.
kkeroo Jun 3, 2024
827d5c8
Updated docs.
kkeroo Jun 3, 2024
25ce9cc
Keypoints message.
kkeroo Jun 5, 2024
47b8dd7
Generic keypoints.
kkeroo Jun 6, 2024
847165f
Correction: handdedness -> handedness.
kkeroo Jun 6, 2024
e2d0534
Init file changed.
kkeroo Jun 6, 2024
f8a48f1
1 channel in 3rd dim. validation.
kkeroo Jun 6, 2024
14a9e4c
Remove unnecessary atributes.
kkeroo Jun 6, 2024
7d9f6d4
Merge branch 'main' into mp_hands
kkeroo Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ml/messages/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .img_detections import ImgDetectionsWithKeypoints
from .img_detections import ImgDetectionsWithKeypoints
from .hand_keypoints import HandKeypoints
21 changes: 21 additions & 0 deletions ml/messages/hand_keypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import depthai as dai
from typing import List

class HandKeypoints(dai.Buffer):
def __init__(self):
dai.Buffer.__init__(self)
self.confidence: float = 0.0
self.handdedness: float = 0.0
self._keypoints: List[dai.Point3f] = []

@property
def keypoints(self) -> List[dai.Point3f]:
return self._keypoints

@keypoints.setter
def keypoints(self, value: List[dai.Point3f]):
if not isinstance(value, list):
raise TypeError("keypoints must be a list.")
for item in value:
if not isinstance(item, dai.Point3f):
raise TypeError("All items in keypoints must be of type dai.Point3f.")
86 changes: 86 additions & 0 deletions ml/postprocessing/mediapipe_hand_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import depthai as dai
import numpy as np
import cv2

from .utils.detection import generate_anchors_and_decode

class MPHandDetectionParser(dai.node.ThreadedHostNode):
def __init__(
self,
score_threshold=0.5,
nms_threshold=0.5,
top_k=100
):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.top_k = top_k

def setConfidenceThreshold(self, threshold):
self.score_threshold = threshold

def setNMSThreshold(self, threshold):
self.nms_threshold = threshold

def setTopK(self, top_k):
self.top_k = top_k

def run(self):
"""
Postprocessing logic for MediPipe Hand detection model.

Returns:
dai.ImgDetections containing bounding boxes, labels, and confidence scores of detected hands.
"""

while self.isRunning():

try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException as e:
break # Pipeline was stopped

tensorInfo = output.getTensorInfo("Identity")
bboxes = output.getTensor(f"Identity").reshape(2016, 18).astype(np.float32)
bboxes = (bboxes - tensorInfo.qpZp) * tensorInfo.qpScale
tensorInfo = output.getTensorInfo("Identity_1")
scores = output.getTensor(f"Identity_1").reshape(2016).astype(np.float32)
scores = (scores - tensorInfo.qpZp) * tensorInfo.qpScale

decoded_bboxes = generate_anchors_and_decode(bboxes=bboxes, scores=scores, threshold=self.score_threshold, scale=192)

bboxes = []
scores = []

for hand in decoded_bboxes:
extended_points = hand.rect_points
xmin = int(min(extended_points[0][0], extended_points[1][0]))
ymin = int(min(extended_points[0][1], extended_points[1][1]))
xmax = int(max(extended_points[2][0], extended_points[3][0]))
ymax = int(max(extended_points[2][1], extended_points[3][1]))

bboxes.append([xmin, ymin, xmax, ymax])
scores.append(hand.pd_score)

indices = cv2.dnn.NMSBoxes(bboxes, scores, self.score_threshold, self.nms_threshold, top_k=self.top_k)
bboxes = np.array(bboxes)[indices]
scores = np.array(scores)[indices]

detections = []

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be in a create function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function added here.

for bbox, score in zip(bboxes, scores):
detection = dai.ImgDetection()
detection.confidence = score
detection.label = 0
detection.xmin = bbox[0]
detection.ymin = bbox[1]
detection.xmax = bbox[2]
detection.ymax = bbox[3]
detections.append(detection)

detections_msg = dai.ImgDetections()
detections_msg.detections = detections

self.out.send(detections_msg)
68 changes: 68 additions & 0 deletions ml/postprocessing/mediapipe_hand_landmarker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import depthai as dai
import numpy as np
import cv2

from ..messages import HandKeypoints

class MPHandLandmarkParser(dai.node.ThreadedHostNode):
def __init__(
self,
score_threshold=0.5,
scale_factor=224
):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.score_threshold = score_threshold
self.scale_factor = scale_factor

def setScoreThreshold(self, threshold):
self.score_threshold = threshold

def setScaleFactor(self, scale_factor):
self.scale_factor = scale_factor

def run(self):
"""
Postprocessing logic for MediaPipe Hand landmark model.

Returns:
HandLandmarks containing normalized 21 landmarks, confidence score, and handdedness score (right or left hand).
"""

while self.isRunning():

try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException as e:
break # Pipeline was stopped

tensorInfo = output.getTensorInfo("Identity")
landmarks = output.getTensor(f"Identity").reshape(21, 3).astype(np.float32)
landmarks = (landmarks - tensorInfo.qpZp) * tensorInfo.qpScale
tensorInfo = output.getTensorInfo("Identity_1")
hand_score = output.getTensor(f"Identity_1").reshape(-1).astype(np.float32)
hand_score = (hand_score - tensorInfo.qpZp) * tensorInfo.qpScale
hand_score = hand_score[0]
tensorInfo = output.getTensorInfo("Identity_2")
handdedness = output.getTensor(f"Identity_2").reshape(-1).astype(np.float32)
handdedness = (handdedness - tensorInfo.qpZp) * tensorInfo.qpScale
handdedness = handdedness[0]

# normalize landmarks
landmarks /= self.scale_factor

hand_landmarks_msg = HandKeypoints()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_message function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function added here.

hand_landmarks_msg.handdedness = handdedness
hand_landmarks_msg.confidence = hand_score
hand_landmarks = []
if hand_score >= self.score_threshold:
for i in range(21):
pt = dai.Point3f()
pt.x = landmarks[i][0]
pt.y = landmarks[i][1]
pt.z = landmarks[i][2]
hand_landmarks.append(pt)
hand_landmarks_msg.landmarks = hand_landmarks
self.out.send(hand_landmarks_msg)
9 changes: 2 additions & 7 deletions ml/postprocessing/scrfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@
import numpy as np
import cv2

from ..custom_messages.img_detections import ImgDetectionsWithKeypoints
from ..messages import ImgDetectionsWithKeypoints

class SCRFDParser(dai.node.ThreadedHostNode):
def __init__(
self,
score_threshold=0.5,
nms_threshold=0.5,
top_k=100,
input_size=(640, 640), # WH
top_k=100
):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.input_size = input_size
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.top_k = top_k
Expand All @@ -30,9 +28,6 @@ def setNMSThreshold(self, threshold):
def setTopK(self, top_k):
self.top_k = top_k

def setInputSize(self, width, height):
self.input_size = (width, height)

def run(self):
"""
Postprocessing logic for SCRFD model.
Expand Down
49 changes: 49 additions & 0 deletions ml/postprocessing/segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import depthai as dai
import numpy as np
import cv2
from .utils.message_creation import create_segmentation_message

class SegmentationParser(dai.node.ThreadedHostNode):
def __init__(
self,
threshold=0.5,
num_classes=2,
):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.threshold = threshold
self.num_classes = num_classes

def setConfidenceThreshold(self, threshold):
self.threshold = threshold

def setNumClasses(self, num_classes):
self.num_classes = num_classes

def run(self):
"""
Postprocessing logic for Segmentation model with `num_classes` classes including background at index 0.

Returns:
Segmenation mask with `num_classes` classes, 0 - background.
"""

while self.isRunning():

try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException as e:
break # Pipeline was stopped

segmentation_mask = output.getTensor("output")
segmentation_mask = segmentation_mask[0] # num_clases x H x W
overlay_image = np.zeros((segmentation_mask.shape[1], segmentation_mask.shape[2], 1), dtype=np.uint8)

for class_id in range(self.num_classes-1):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldnt overlay_image be just an argmax over segmentation mask?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, argmax added here. One 'layer' is added in the first dimension with all zeros so first class has index 1.

class_mask = segmentation_mask[class_id] > self.threshold
overlay_image[class_mask] = class_id + 1

imgFrame = create_segmentation_message(overlay_image)
self.out.send(imgFrame)
58 changes: 0 additions & 58 deletions ml/postprocessing/selfie_seg.py

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mediapipe_selfie_segmentation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

This file was deleted.

11 changes: 11 additions & 0 deletions ml/postprocessing/utils/detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .medipipe import generate_handtracker_anchors, decode_bboxes, detections_to_rect, rect_transformation

def generate_anchors_and_decode(bboxes, scores, threshold=0.5, scale=192):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given this only contains mediapipe functions, I don't think it belongs here.

When we adapt mediapipe utils and make it more generic, we can move certain parts here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed here

"""
Generate anchors and decode bounding boxes for mediapipe hand detection model.
"""
anchors = generate_handtracker_anchors(scale, scale)
decoded_bboxes = decode_bboxes(threshold, scores, bboxes, anchors, scale=scale)
detections_to_rect(decoded_bboxes)
rect_transformation(decoded_bboxes, scale, scale)
return decoded_bboxes
Loading