Skip to content

Commit

Permalink
New model parsers.
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeroo authored Jun 28, 2024
2 parents 83b51c2 + e50e75a commit 0c0edd6
Show file tree
Hide file tree
Showing 17 changed files with 770 additions and 35 deletions.
3 changes: 3 additions & 0 deletions ml/messages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from .img_detections import ImgDetectionWithKeypoints, ImgDetectionsWithKeypoints
from .keypoints import HandKeypoints, Keypoints
from .lines import Line, Lines

__all__ = [
"ImgDetectionWithKeypoints",
"ImgDetectionsWithKeypoints",
"HandKeypoints",
"Keypoints",
"Line",
"Lines",
]
8 changes: 6 additions & 2 deletions ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .image import create_image_message
from .segmentation import create_segmentation_message
from .keypoints import create_hand_keypoints_message
from .detection import create_detection_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .detection import create_detection_message, create_line_detection_message
from .tracked_features import create_tracked_features_message
from .depth import create_depth_message

__all__ = [
Expand All @@ -10,4 +11,7 @@
"create_hand_keypoints_message",
"create_detection_message",
"create_depth_message",
"create_line_detection_message",
"create_tracked_features_message",
"create_keypoints_message",
]
50 changes: 50 additions & 0 deletions ml/messages/creators/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ...messages import (
ImgDetectionWithKeypoints,
ImgDetectionsWithKeypoints,
Line,
Lines,
)


Expand Down Expand Up @@ -113,3 +115,51 @@ def create_detection_message(
detections_msg = img_detections()
detections_msg.detections = detections
return detections_msg

def create_line_detection_message(lines: np.ndarray, scores: np.ndarray):
"""
Create a message for the line detection. The message contains the lines and confidence scores of detected lines.
Args:
lines (np.ndarray): Detected lines of shape (N,4) meaning [...,[x_start, y_start, x_end, y_end],...].
scores (np.ndarray): Confidence scores of detected lines of shape (N,).
Returns:
dai.Lines: Message containing the lines and confidence scores of detected lines.
"""

# checks for lines
if not isinstance(lines, np.ndarray):
raise ValueError(f"lines should be numpy array, got {type(lines)}.")
if len(lines) != 0:
if len(lines.shape) != 2:
raise ValueError(
f"lines should be of shape (N,4) meaning [...,[x_start, y_start, x_end, y_end],...], got {lines.shape}."
)
if lines.shape[1] != 4:
raise ValueError(
f"lines 2nd dimension should be of size 4 e.g. [x_start, y_start, x_end, y_end] got {lines.shape[1]}."
)

# checks for scores
if not isinstance(scores, np.ndarray):
raise ValueError(f"scores should be numpy array, got {type(scores)}.")
if len(scores) != 0:
if len(scores.shape) != 1:
raise ValueError(f"scores should be of shape (N,) meaning, got {scores.shape}.")
if scores.shape[0] != lines.shape[0]:
raise ValueError(
f"scores should have same length as lines, got {scores.shape[0]} and {lines.shape[0]}."
)

line_detections = []
for i, line in enumerate(lines):
line_detection = Line()
line_detection.start_point = dai.Point2f(line[0], line[1])
line_detection.end_point = dai.Point2f(line[2], line[3])
line_detection.confidence = float(scores[i])
line_detections.append(line_detection)

lines_msg = Lines()
lines_msg.lines = line_detections
return lines_msg
72 changes: 69 additions & 3 deletions ml/messages/creators/keypoints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import depthai as dai
import numpy as np
from typing import List
from ...messages import HandKeypoints
from typing import List, Union
from ...messages import HandKeypoints, Keypoints

def create_hand_keypoints_message(hand_keypoints: np.ndarray, handedness: float, confidence: float, confidence_threshold: float) -> HandKeypoints:
"""
Expand Down Expand Up @@ -41,4 +41,70 @@ def create_hand_keypoints_message(hand_keypoints: np.ndarray, handedness: float,
points.append(pt)
hand_keypoints_msg.keypoints = points

return hand_keypoints_msg
return hand_keypoints_msg

def create_keypoints_message(keypoints: Union[np.ndarray, List[List[float]]], scores: Union[np.ndarray, List[float]] = None, confidence_threshold: float = None) -> Keypoints:
"""
Create a message for the keypoints. The message contains 2D or 3D coordinates of the detected keypoints.
Args:
keypoints (np.ndarray OR List[List[float]]): Detected keypoints of shape (N,2 or 3) meaning [...,[x, y],...] or [...,[x, y, z],...].
scores (np.ndarray or List[float]): Confidence scores of the detected keypoints.
confidence_threshold (float): Confidence threshold for the keypoints.
Returns:
Keypoints: Message containing 2D or 3D coordinates of the detected keypoints.
"""

if not isinstance(keypoints, np.ndarray):
if not isinstance(keypoints, list):
raise ValueError(f"keypoints should be numpy array or list, got {type(keypoints)}.")
for keypoint in keypoints:
if not isinstance(keypoint, list):
raise ValueError(f"keypoints should be list of lists or np.array, got list of {type(keypoint)}.")
if len(keypoint) not in [2, 3]:
raise ValueError(f"keypoints inner list should be of size 2 or 3 e.g. [x, y] or [x, y, z], got {len(keypoint)}.")
for coord in keypoint:
if not isinstance(coord, (float)):
raise ValueError(f"keypoints inner list should contain only float, got {type(coord)}.")
keypoints = np.array(keypoints)
if len(keypoints.shape) != 2:
raise ValueError(f"keypoints should be of shape (N,2 or 3) got {keypoints.shape}.")
if keypoints.shape[1] not in [2, 3]:
raise ValueError(f"keypoints 2nd dimension should be of size 2 or 3 e.g. [x, y] or [x, y, z], got {keypoints.shape[1]}.")
if scores is not None:
if not isinstance(scores, np.ndarray):
if not isinstance(scores, list):
raise ValueError(f"scores should be numpy array or list, got {type(scores)}.")
for score in scores:
if not isinstance(score, float):
raise ValueError(f"scores should be list of floats or np.array, got list of {type(score)}.")
scores = np.array(scores)
if len(scores.shape) != 1:
raise ValueError(f"scores should be of shape (N,) meaning [...,score,...], got {scores.shape}.")
if keypoints.shape[0] != scores.shape[0]:
raise ValueError(f"keypoints and scores should have the same length, got {keypoints.shape[0]} and {scores.shape[0]}.")
if confidence_threshold is None:
raise ValueError(f"confidence_threshold should be provided when scores are provided.")
if confidence_threshold is not None:
if not isinstance(confidence_threshold, float):
raise ValueError(f"confidence_threshold should be float, got {type(confidence_threshold)}.")
if scores is None:
raise ValueError(f"confidence_threshold should be provided when scores are provided.")

use_3d = keypoints.shape[1] == 3

keypoints_msg = Keypoints()
points = []
for i, keypoint in enumerate(keypoints):
if scores is not None:
if scores[i] < confidence_threshold:
continue
pt = dai.Point3f()
pt.x = keypoint[0]
pt.y = keypoint[1]
pt.z = keypoint[2] if use_3d else 0
points.append(pt)

keypoints_msg.keypoints = points
return keypoints_msg
63 changes: 63 additions & 0 deletions ml/messages/creators/tracked_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import depthai as dai

def create_feature_point(x: float, y: float, id: int, age: int) -> dai.TrackedFeature:
"""
Create a tracked feature point.
Args:
x (float): X coordinate of the feature point.
y (float): Y coordinate of the feature point.
id (int): ID of the feature point.
age (int): Age of the feature point.
Returns:
dai.TrackedFeature: Tracked feature point.
"""

feature = dai.TrackedFeature()
feature.position.x = x
feature.position.y = y
feature.id = id
feature.age = age

return feature

def create_tracked_features_message(reference_points: np.ndarray, target_points: np.ndarray) -> dai.TrackedFeatures:
"""
Create a message for the tracked features.
Args:
reference_points (np.ndarray): Reference points of shape (N,2) meaning [...,[x, y],...].
target_points (np.ndarray): Target points of shape (N,2) meaning [...,[x, y],...].
Returns:
dai.TrackedFeatures: Message containing the tracked features.
"""


if not isinstance(reference_points, np.ndarray):
raise ValueError(f"reference_points should be numpy array, got {type(reference_points)}.")
if len(reference_points.shape) != 2:
raise ValueError(f"reference_points should be of shape (N,2) meaning [...,[x, y],...], got {reference_points.shape}.")
if reference_points.shape[1] != 2:
raise ValueError(f"reference_points 2nd dimension should be of size 2 e.g. [x, y], got {reference_points.shape[1]}.")
if not isinstance(target_points, np.ndarray):
raise ValueError(f"target_points should be numpy array, got {type(target_points)}.")
if len(target_points.shape) != 2:
raise ValueError(f"target_points should be of shape (N,2) meaning [...,[x, y],...], got {target_points.shape}.")
if target_points.shape[1] != 2:
raise ValueError(f"target_points 2nd dimension should be of size 2 e.g. [x, y], got {target_points.shape[1]}.")

features = []

for i in range(reference_points.shape[0]):
reference_feature = create_feature_point(reference_points[i][0], reference_points[i][1], i, 0)
target_feature = create_feature_point(target_points[i][0], target_points[i][1], i, 1)
features.append(reference_feature)
features.append(target_feature)

features_msg = dai.TrackedFeatures()
features_msg.trackedFeatures = features

return features_msg
58 changes: 58 additions & 0 deletions ml/messages/lines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import depthai as dai
from typing import List

class Line(dai.Buffer):
def __init__(self):
super().__init__()
self._start_point: dai.Point2f = None
self._end_point: dai.Point2f = None
self._confidence: float = None

@property
def start_point(self) -> dai.Point2f:
return self._start_point

@start_point.setter
def start_point(self, value: dai.Point2f):
if not isinstance(value, dai.Point2f):
raise TypeError(f"start_point must be of type Point2f, instead got {type(value)}.")
self._start_point = value

@property
def end_point(self) -> dai.Point2f:
return self._end_point

@end_point.setter
def end_point(self, value: dai.Point2f):
if not isinstance(value, dai.Point2f):
raise TypeError(f"end_point must be of type Point2f, instead got {type(value)}.")
self._end_point = value

@property
def confidence(self) -> float:
return self._confidence

@confidence.setter
def confidence(self, value: float):
if not isinstance(value, float):
raise TypeError(f"confidence must be of type float, instead got {type(value)}.")
self._confidence = value


class Lines(dai.Buffer):
def __init__(self):
super().__init__()
self._lines: List[Line] = []

@property
def lines(self) -> List[Line]:
return self._lines

@lines.setter
def lines(self, value: List[Line]):
if not isinstance(value, List):
raise TypeError(f"lines must be of type List[Line], instead got {type(value)}.")
for line in value:
if not isinstance(line, Line):
raise TypeError(f"lines must be of type List[Line], instead got {type(value)}.")
self._lines = value
10 changes: 9 additions & 1 deletion ml/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from .mediapipe_hand_landmarker import MPHandLandmarkParser
from .scrfd import SCRFDParser
from .segmentation import SegmentationParser
from .superanimal_landmarker import SuperAnimalParser
from .keypoints import KeypointParser
from .mlsd import MLSDParser
from .xfeat import XFeatParser

__all__ = [
'ImageOutputParser',
Expand All @@ -14,4 +18,8 @@
'MPHandLandmarkParser',
'SCRFDParser',
'SegmentationParser',
]
'SuperAnimalParser',
'KeypointParser',
'MLSDParser',
'XFeatParser',
]
60 changes: 60 additions & 0 deletions ml/postprocessing/keypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import depthai as dai
import numpy as np

from ..messages.creators import create_keypoints_message

class KeypointParser(dai.node.ThreadedHostNode):
def __init__(
self,
scale_factor=1,
num_keypoints=None,
):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.scale_factor = scale_factor
self.num_keypoints = num_keypoints

def setScaleFactor(self, scale_factor):
self.scale_factor = scale_factor

def setNumKeypoints(self, num_keypoints):
self.num_keypoints = num_keypoints

def run(self):
"""
Postprocessing logic for Keypoint model.
Returns:
dai.Keypoints: num_keypoints keypoints (2D or 3D).
"""

if self.num_keypoints is None:
raise ValueError("Number of keypoints must be specified!")

while self.isRunning():

try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException as e:
break # Pipeline was stopped

output_layer_names = output.getAllLayerNames()

if len(output_layer_names) != 1:
raise ValueError(f"Expected 1 output layer, got {len(output_layer_names)}.")

keypoints = output.getTensor(output_layer_names[0])
num_coords = int(np.prod(keypoints.shape) / self.num_keypoints)

if num_coords not in [2, 3]:
raise ValueError(f"Expected 2 or 3 coordinates per keypoint, got {num_coords}.")

keypoints = keypoints.reshape(self.num_keypoints, num_coords)

keypoints /= self.scale_factor

msg = create_keypoints_message(keypoints)

self.out.send(msg)
4 changes: 0 additions & 4 deletions ml/postprocessing/mediapipe_hand_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ def run(self):
except dai.MessageQueue.QueueException as e:
break # Pipeline was stopped

tensorInfo = output.getTensorInfo("Identity")
bboxes = output.getTensor(f"Identity").reshape(2016, 18).astype(np.float32)
bboxes = (bboxes - tensorInfo.qpZp) * tensorInfo.qpScale
tensorInfo = output.getTensorInfo("Identity_1")
scores = output.getTensor(f"Identity_1").reshape(2016).astype(np.float32)
scores = (scores - tensorInfo.qpZp) * tensorInfo.qpScale

decoded_bboxes = generate_anchors_and_decode(bboxes=bboxes, scores=scores, threshold=self.score_threshold, scale=192)

Expand Down
Loading

0 comments on commit 0c0edd6

Please sign in to comment.