-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from luxonis/mp_hands
Mp Hands nodes & segmentation msg creation
- Loading branch information
Showing
15 changed files
with
742 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
from .img_detections import ImgDetectionWithKeypoints, ImgDetectionsWithKeypoints | ||
from .keypoints import HandKeypoints, Keypoints | ||
|
||
__all__ = [ | ||
"ImgDetectionWithKeypoints", | ||
"ImgDetectionsWithKeypoints", | ||
] | ||
"HandKeypoints", | ||
"Keypoints", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import depthai as dai | ||
from typing import List | ||
|
||
|
||
class Keypoints(dai.Buffer):
    """A ``dai.Buffer`` subclass that stores a list of ``dai.Point3f`` keypoints."""

    def __init__(self):
        super().__init__()
        # Starts out empty; populated through the `keypoints` property.
        self._keypoints: List[dai.Point3f] = []

    @property
    def keypoints(self) -> List[dai.Point3f]:
        """Return the stored list of keypoints."""
        return self._keypoints

    @keypoints.setter
    def keypoints(self, value: List[dai.Point3f]):
        # Only a genuine `list` is accepted (tuples and arrays are rejected).
        if not isinstance(value, list):
            raise TypeError("keypoints must be a list.")
        # Every element has to be a dai.Point3f; an empty list passes.
        if not all(isinstance(point, dai.Point3f) for point in value):
            raise TypeError("All items in keypoints must be of type dai.Point3f.")
        # The list is stored by reference, not copied.
        self._keypoints = value
|
||
|
||
class HandKeypoints(Keypoints):
    """Keypoints message extended with a confidence and a handedness value.

    The handedness attribute is publicly spelled ``handdedness`` (sic). The
    spelling is kept as-is because it is the name callers read and assign.
    """

    def __init__(self):
        # Use super() like the base class does, so initialisation stays
        # cooperative if the hierarchy ever changes.
        super().__init__()
        self._confidence: float = 0.0
        self._handdedness: float = 0.0

    @property
    def confidence(self) -> float:
        """Return the stored confidence value."""
        return self._confidence

    @confidence.setter
    def confidence(self, value: float):
        # Strict check: ints and NumPy scalars other than float subclasses
        # are rejected, so callers must convert with float(...) first.
        if not isinstance(value, float):
            raise TypeError("confidence must be a float.")
        self._confidence = value

    @property
    def handdedness(self) -> float:
        """Return the stored handedness value."""
        return self._handdedness

    @handdedness.setter
    def handdedness(self, value: float):
        # Same strict float-only validation as `confidence`.
        if not isinstance(value, float):
            raise TypeError("handdedness must be a float.")
        self._handdedness = value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import depthai as dai | ||
import numpy as np | ||
import cv2 | ||
|
||
from .utils.message_creation import create_detection_message | ||
from .utils.medipipe import generate_anchors_and_decode | ||
|
||
class MPHandDetectionParser(dai.node.ThreadedHostNode):
    """Host node that post-processes the MediaPipe hand detection model output.

    Reads ``dai.NNData`` from ``input``, decodes and filters the candidate
    boxes, applies non-maximum suppression and sends the resulting detection
    message on ``out``.

    Args:
        score_threshold: Minimum score used both when decoding candidates and
            as the NMS score threshold.
        nms_threshold: Overlap threshold passed to NMS.
        top_k: Maximum number of boxes NMS may keep.
    """

    def __init__(
        self,
        score_threshold=0.5,
        nms_threshold=0.5,
        top_k=100
    ):
        dai.node.ThreadedHostNode.__init__(self)
        self.input = dai.Node.Input(self)
        self.out = dai.Node.Output(self)

        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.top_k = top_k

    def setConfidenceThreshold(self, threshold):
        """Set the score threshold used for decoding and NMS."""
        self.score_threshold = threshold

    def setNMSThreshold(self, threshold):
        """Set the overlap threshold used by NMS."""
        self.nms_threshold = threshold

    def setTopK(self, top_k):
        """Set the maximum number of boxes kept by NMS."""
        self.top_k = top_k

    def run(self):
        """
        Postprocessing logic for MediaPipe Hand detection model.

        Loops while the node is running; for each received message it sends
        the message built by `create_detection_message` from the kept
        bounding boxes and scores (labels are passed as None).
        """

        while self.isRunning():

            try:
                output: dai.NNData = self.input.get()
            except dai.MessageQueue.QueueException:
                break  # Pipeline was stopped

            # Dequantize both output layers: (raw - zero_point) * scale.
            tensorInfo = output.getTensorInfo("Identity")
            bboxes = output.getTensor("Identity").reshape(2016, 18).astype(np.float32)
            bboxes = (bboxes - tensorInfo.qpZp) * tensorInfo.qpScale
            tensorInfo = output.getTensorInfo("Identity_1")
            scores = output.getTensor("Identity_1").reshape(2016).astype(np.float32)
            scores = (scores - tensorInfo.qpZp) * tensorInfo.qpScale

            decoded_bboxes = generate_anchors_and_decode(
                bboxes=bboxes,
                scores=scores,
                threshold=self.score_threshold,
                scale=192,
            )

            boxes_xyxy = []  # corner format, forwarded to the output message
            boxes_xywh = []  # [x, y, width, height], the layout NMSBoxes expects
            scores = []

            for hand in decoded_bboxes:
                extended_points = hand.rect_points
                xmin = int(min(extended_points[0][0], extended_points[1][0]))
                ymin = int(min(extended_points[0][1], extended_points[1][1]))
                xmax = int(max(extended_points[2][0], extended_points[3][0]))
                ymax = int(max(extended_points[2][1], extended_points[3][1]))

                boxes_xyxy.append([xmin, ymin, xmax, ymax])
                boxes_xywh.append([xmin, ymin, xmax - xmin, ymax - ymin])
                scores.append(hand.pd_score)

            # cv2.dnn.NMSBoxes treats each box as a Rect (x, y, w, h); feeding
            # it corner coordinates would distort the overlap computation.
            indices = cv2.dnn.NMSBoxes(
                boxes_xywh,
                scores,
                self.score_threshold,
                self.nms_threshold,
                top_k=self.top_k,
            )
            # Depending on the OpenCV version and on whether anything was
            # kept, the result is an (N,) array, an (N, 1) array or an empty
            # tuple. Flatten to a plain 1-D integer index in every case.
            keep = np.asarray(indices, dtype=np.int64).reshape(-1)

            # reshape(-1, 4) keeps the (N, 4) layout even when N == 0.
            bboxes = np.array(boxes_xyxy, dtype=int).reshape(-1, 4)[keep]
            scores = np.array(scores)[keep]

            detections_msg = create_detection_message(bboxes, scores, labels=None)
            self.out.send(detections_msg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import depthai as dai | ||
import numpy as np | ||
import cv2 | ||
|
||
from .utils.message_creation import create_hand_keypoints_message | ||
|
||
class MPHandLandmarkParser(dai.node.ThreadedHostNode):
    """Host node that post-processes the MediaPipe hand landmark model output.

    Reads ``dai.NNData`` from ``input`` and sends the message built by
    `create_hand_keypoints_message` on ``out``.
    """

    def __init__(
        self,
        score_threshold=0.5,
        scale_factor=224
    ):
        dai.node.ThreadedHostNode.__init__(self)
        self.input = dai.Node.Input(self)
        self.out = dai.Node.Output(self)

        self.score_threshold = score_threshold
        self.scale_factor = scale_factor

    def setScoreThreshold(self, threshold):
        """Set the score threshold forwarded to the message builder."""
        self.score_threshold = threshold

    def setScaleFactor(self, scale_factor):
        """Set the divisor used to normalize the landmark coordinates."""
        self.scale_factor = scale_factor

    @staticmethod
    def _read_dequantized(nn_data, layer_name, shape):
        # Fetch one output layer as float32 in the given shape and undo the
        # quantization: (raw - zero_point) * scale.
        info = nn_data.getTensorInfo(layer_name)
        raw = nn_data.getTensor(layer_name).reshape(shape).astype(np.float32)
        return (raw - info.qpZp) * info.qpScale

    def run(self):
        """
        Postprocessing logic for MediaPipe Hand landmark model.

        Loops while the node is running. Each received message yields a
        (21, 3) landmark array divided by `scale_factor`, plus the first
        element of the hand-score and handedness layers; these are passed to
        `create_hand_keypoints_message` together with `score_threshold`.
        """

        while self.isRunning():

            try:
                output: dai.NNData = self.input.get()
            except dai.MessageQueue.QueueException:
                break  # Pipeline was stopped

            landmarks = self._read_dequantized(output, "Identity", (21, 3))
            hand_score = self._read_dequantized(output, "Identity_1", -1)[0]
            handedness = self._read_dequantized(output, "Identity_2", -1)[0]

            # normalize landmarks
            landmarks /= self.scale_factor

            hand_landmarks_msg = create_hand_keypoints_message(
                landmarks,
                float(handedness),
                float(hand_score),
                self.score_threshold,
            )
            self.out.send(hand_landmarks_msg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import depthai as dai | ||
import numpy as np | ||
import cv2 | ||
from .utils.message_creation import create_segmentation_message | ||
|
||
class SegmentationParser(dai.node.ThreadedHostNode):
    """Host node that turns a segmentation model output into a class-index mask.

    Reads ``dai.NNData`` from ``input`` and sends the message built by
    `create_segmentation_message` on ``out``.
    """

    def __init__(self):
        dai.node.ThreadedHostNode.__init__(self)
        self.input = dai.Node.Input(self)
        self.out = dai.Node.Output(self)

    def run(self):
        """
        Postprocessing logic for Segmentation model with `num_classes` classes including background at index 0.

        Loops while the node is running. For each received message an
        all-zero plane is prepended as index 0 (background), the arg-max over
        the leading axis is taken, and the result is sent as a uint8 mask with
        a trailing singleton axis.
        """

        while self.isRunning():

            try:
                output: dai.NNData = self.input.get()
            except dai.MessageQueue.QueueException:
                break  # Pipeline was stopped

            # First element along the leading axis (presumably the batch
            # dimension - TODO confirm); the rest is indexed as classes x H x W.
            class_scores = output.getTensor("output")[0]
            height = class_scores.shape[1]
            width = class_scores.shape[2]

            # Zero-valued background plane goes in front, so it becomes index 0.
            background = np.zeros((1, height, width), dtype=np.float32)
            stacked = np.concatenate((background, class_scores), axis=0)

            # Winning index per pixel, shaped (H, W, 1) as uint8.
            class_map = np.argmax(stacked, axis=0)
            overlay_image = class_map[:, :, np.newaxis].astype(np.uint8)

            imgFrame = create_segmentation_message(overlay_image)
            self.out.send(imgFrame)
This file was deleted.
Oops, something went wrong.
File renamed without changes.
Oops, something went wrong.