Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parsers for HRNet and AgeGender models #16

Merged
merged 9 commits into from
Aug 28, 2024
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .img_detections import ImgDetectionsWithKeypoints, ImgDetectionWithKeypoints
from .keypoints import HandKeypoints, Keypoints
from .lines import Line, Lines
from .misc import AgeGender

__all__ = [
"ImgDetectionWithKeypoints",
Expand All @@ -11,4 +12,5 @@
"Line",
"Lines",
"Classifications",
"AgeGender",
]
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .detection import create_detection_message, create_line_detection_message
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .misc import create_age_gender_message
from .segmentation import create_segmentation_message
from .thermal import create_thermal_message
from .tracked_features import create_tracked_features_message
Expand All @@ -18,4 +19,5 @@
"create_keypoints_message",
"create_thermal_message",
"create_classification_message",
"create_age_gender_message",
]
39 changes: 39 additions & 0 deletions depthai_nodes/ml/messages/creators/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List

from ...messages import AgeGender, Classifications


def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender:
"""Create a DepthAI message for the age and gender probability.

@param age: Detected person age (must be multiplied by 100 to get years).
@type age: float
@param gender_prob: Detected person gender probability [female, male].
@type gender_prob: List[float]
@return: AgeGender message containing the predicted person's age and Classifications
message containing the classes and probabilities of the predicted gender.
@rtype: AgeGender
@raise ValueError: If age is not a float.
@raise ValueError: If gender_prob is not a list.
@raise ValueError: If each item in gender_prob is not a float.
"""

if not isinstance(age, float):
raise ValueError(f"age should be float, got {type(age)}.")

if not isinstance(gender_prob, List):
raise ValueError(f"gender_prob should be list, got {type(gender_prob)}.")
for item in gender_prob:
if not isinstance(item, float):
raise ValueError(
f"gender_prob list values must be of type float, instead got {type(item)}."
)

age_gender_message = AgeGender()
age_gender_message.age = age
gender = Classifications()
gender.classes = ["female", "male"]
gender.scores = gender_prob
age_gender_message.gender = gender

return age_gender_message
34 changes: 34 additions & 0 deletions depthai_nodes/ml/messages/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import depthai as dai

from ..messages import Classifications


class AgeGender(dai.Buffer):
def __init__(self):
super().__init__()
self._age: float = None
self._gender = Classifications()

@property
def age(self) -> float:
return self._age

@age.setter
def age(self, value: float):
if not isinstance(value, float):
raise TypeError(
f"start_point must be of type float, instead got {type(value)}."
)
self._age = value

@property
def gender(self) -> Classifications:
return self._gender

@gender.setter
def gender(self, value: Classifications):
if not isinstance(value, Classifications):
raise TypeError(
f"gender must be of type Classifications, instead got {type(value)}."
)
self._gender = value
4 changes: 4 additions & 0 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .age_gender import AgeGenderParser
from .classification import ClassificationParser
from .hrnet import HRNetParser
from .image_output import ImageOutputParser
from .keypoints import KeypointParser
from .mediapipe_hand_landmarker import MPHandLandmarkParser
Expand Down Expand Up @@ -26,4 +28,6 @@
"XFeatParser",
"ThermalImageParser",
"ClassificationParser",
"AgeGenderParser",
"HRNetParser",
]
43 changes: 43 additions & 0 deletions depthai_nodes/ml/parsers/age_gender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import depthai as dai

from ..messages.creators import create_age_gender_message


class AgeGenderParser(dai.node.ThreadedHostNode):
"""Parser class for parsing the output of the Age-Gender regression model.

Attributes
----------
input : Node.Input
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node.
out : Node.Output
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved.

Output Message/s
----------------
**Type**: AgeGender

**Description**: Message containing the detected person age and Classfications object for storing information about the detected person's gender.
"""

def __init__(self):
"""Initializes the AgeGenderParser node."""
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If going with the proposed change add parameter for thresholding?


def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

age = output.getTensor("age_conv3", dequantize=True).item()
age *= 100 # convert to years
prob = output.getTensor("prob", dequantize=True).flatten().tolist()

age_gender_message = create_age_gender_message(age=age, gender_prob=prob)
age_gender_message.setTimestamp(output.getTimestamp())

self.out.send(age_gender_message)
80 changes: 80 additions & 0 deletions depthai_nodes/ml/parsers/hrnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import depthai as dai
import numpy as np

from ..messages.creators import create_keypoints_message


class HRNetParser(dai.node.ThreadedHostNode):
"""Parser class for parsing the output of the HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation.

Attributes
----------
input : Node.Input
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node.
out : Node.Output
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved.
score_threshold : float
Confidence score threshold for detected keypoints.

Output Message/s
----------------
**Type**: Keypoints

**Description**: Keypoints message containing detected body keypoints.
"""

def __init__(self, score_threshold=0.5):
"""Initializes the HRNetParser node.

@param score_threshold: Confidence score threshold for detected keypoints.
@type score_threshold: float
"""
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.score_threshold = score_threshold

def setScoreThreshold(self, threshold):
"""Sets the confidence score threshold for the detected body keypoints.

@param threshold: Confidence score threshold for detected keypoints.
@type threshold: float
"""
self.score_threshold = threshold

def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

heatmaps = output.getTensor("heatmaps", dequantize=True)

if len(heatmaps.shape) == 4:
heatmaps = heatmaps[0]
if heatmaps.shape[2] == 16: # HW_ instead of _HW
heatmaps = heatmaps.transpose(2, 0, 1)
_, map_h, map_w = heatmaps.shape

scores = np.array([np.max(heatmap) for heatmap in heatmaps])
keypoints = np.array(
[
np.unravel_index(heatmap.argmax(), heatmap.shape)
for heatmap in heatmaps
]
)
keypoints = keypoints.astype(np.float32)
keypoints = keypoints[:, ::-1] / np.array(
[map_w, map_h]
) # normalize keypoints to [0, 1]

keypoints_message = create_keypoints_message(
keypoints=keypoints,
scores=scores,
confidence_threshold=self.score_threshold,
)
keypoints_message.setTimestamp(output.getTimestamp())

self.out.send(keypoints_message)