-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add parsers for HRNet and AgeGender models #16
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
33cc8a5
feat: add support for age_gender model
jkbmrz f88b5dc
feat: add support for HRNet model
jkbmrz 4c76f57
fix: formatting and structure
jkbmrz 610ea62
fix: AgeGenderParser formatting and convert age to years
jkbmrz 237ae46
fix: HRNetParser formatting, remove comments, add normalization
jkbmrz 735d85a
fix: add timestamps to outgoing messages
jkbmrz 8a06072
Pre-commit fix.
kkeroo 36b32e5
Add Classifications msg to AgeGender.
kkeroo 960512d
Docstrings fix.
kkeroo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import List | ||
|
||
from ...messages import AgeGender, Classifications | ||
|
||
|
||
def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender: | ||
"""Create a DepthAI message for the age and gender probability. | ||
|
||
@param age: Detected person age (must be multiplied by 100 to get years). | ||
@type age: float | ||
@param gender_prob: Detected person gender probability [female, male]. | ||
@type gender_prob: List[float] | ||
@return: AgeGender message containing the predicted person's age and Classifications | ||
message containing the classes and probabilities of the predicted gender. | ||
@rtype: AgeGender | ||
@raise ValueError: If age is not a float. | ||
@raise ValueError: If gender_prob is not a list. | ||
@raise ValueError: If each item in gender_prob is not a float. | ||
""" | ||
|
||
if not isinstance(age, float): | ||
raise ValueError(f"age should be float, got {type(age)}.") | ||
|
||
if not isinstance(gender_prob, List): | ||
raise ValueError(f"gender_prob should be list, got {type(gender_prob)}.") | ||
for item in gender_prob: | ||
if not isinstance(item, float): | ||
raise ValueError( | ||
f"gender_prob list values must be of type float, instead got {type(item)}." | ||
) | ||
|
||
age_gender_message = AgeGender() | ||
age_gender_message.age = age | ||
gender = Classifications() | ||
gender.classes = ["female", "male"] | ||
gender.scores = gender_prob | ||
age_gender_message.gender = gender | ||
|
||
return age_gender_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import depthai as dai | ||
|
||
from ..messages import Classifications | ||
|
||
|
||
class AgeGender(dai.Buffer): | ||
def __init__(self): | ||
super().__init__() | ||
self._age: float = None | ||
self._gender = Classifications() | ||
|
||
@property | ||
def age(self) -> float: | ||
return self._age | ||
|
||
@age.setter | ||
def age(self, value: float): | ||
if not isinstance(value, float): | ||
raise TypeError( | ||
f"start_point must be of type float, instead got {type(value)}." | ||
) | ||
self._age = value | ||
|
||
@property | ||
def gender(self) -> Classifications: | ||
return self._gender | ||
|
||
@gender.setter | ||
def gender(self, value: Classifications): | ||
if not isinstance(value, Classifications): | ||
raise TypeError( | ||
f"gender must be of type Classifications, instead got {type(value)}." | ||
) | ||
self._gender = value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import depthai as dai | ||
|
||
from ..messages.creators import create_age_gender_message | ||
|
||
|
||
class AgeGenderParser(dai.node.ThreadedHostNode): | ||
"""Parser class for parsing the output of the Age-Gender regression model. | ||
|
||
Attributes | ||
---------- | ||
input : Node.Input | ||
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. | ||
out : Node.Output | ||
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. | ||
|
||
Output Message/s | ||
---------------- | ||
**Type**: AgeGender | ||
|
||
**Description**: Message containing the detected person age and Classfications object for storing information about the detected person's gender. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initializes the AgeGenderParser node.""" | ||
dai.node.ThreadedHostNode.__init__(self) | ||
self.input = dai.Node.Input(self) | ||
self.out = dai.Node.Output(self) | ||
|
||
def run(self): | ||
while self.isRunning(): | ||
try: | ||
output: dai.NNData = self.input.get() | ||
except dai.MessageQueue.QueueException: | ||
break # Pipeline was stopped | ||
|
||
age = output.getTensor("age_conv3", dequantize=True).item() | ||
age *= 100 # convert to years | ||
prob = output.getTensor("prob", dequantize=True).flatten().tolist() | ||
|
||
age_gender_message = create_age_gender_message(age=age, gender_prob=prob) | ||
age_gender_message.setTimestamp(output.getTimestamp()) | ||
|
||
self.out.send(age_gender_message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import depthai as dai | ||
import numpy as np | ||
|
||
from ..messages.creators import create_keypoints_message | ||
|
||
|
||
class HRNetParser(dai.node.ThreadedHostNode): | ||
"""Parser class for parsing the output of the HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation. | ||
|
||
Attributes | ||
---------- | ||
input : Node.Input | ||
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. | ||
out : Node.Output | ||
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. | ||
score_threshold : float | ||
Confidence score threshold for detected keypoints. | ||
|
||
Output Message/s | ||
---------------- | ||
**Type**: Keypoints | ||
|
||
**Description**: Keypoints message containing detected body keypoints. | ||
""" | ||
|
||
def __init__(self, score_threshold=0.5): | ||
"""Initializes the HRNetParser node. | ||
|
||
@param score_threshold: Confidence score threshold for detected keypoints. | ||
@type score_threshold: float | ||
""" | ||
dai.node.ThreadedHostNode.__init__(self) | ||
self.input = dai.Node.Input(self) | ||
self.out = dai.Node.Output(self) | ||
|
||
self.score_threshold = score_threshold | ||
|
||
def setScoreThreshold(self, threshold): | ||
"""Sets the confidence score threshold for the detected body keypoints. | ||
|
||
@param threshold: Confidence score threshold for detected keypoints. | ||
@type threshold: float | ||
""" | ||
self.score_threshold = threshold | ||
|
||
def run(self): | ||
while self.isRunning(): | ||
try: | ||
output: dai.NNData = self.input.get() | ||
except dai.MessageQueue.QueueException: | ||
break # Pipeline was stopped | ||
|
||
heatmaps = output.getTensor("heatmaps", dequantize=True) | ||
|
||
if len(heatmaps.shape) == 4: | ||
heatmaps = heatmaps[0] | ||
if heatmaps.shape[2] == 16: # HW_ instead of _HW | ||
heatmaps = heatmaps.transpose(2, 0, 1) | ||
_, map_h, map_w = heatmaps.shape | ||
|
||
scores = np.array([np.max(heatmap) for heatmap in heatmaps]) | ||
keypoints = np.array( | ||
[ | ||
np.unravel_index(heatmap.argmax(), heatmap.shape) | ||
for heatmap in heatmaps | ||
] | ||
) | ||
keypoints = keypoints.astype(np.float32) | ||
keypoints = keypoints[:, ::-1] / np.array( | ||
[map_w, map_h] | ||
) # normalize keypoints to [0, 1] | ||
|
||
keypoints_message = create_keypoints_message( | ||
keypoints=keypoints, | ||
scores=scores, | ||
confidence_threshold=self.score_threshold, | ||
) | ||
keypoints_message.setTimestamp(output.getTimestamp()) | ||
|
||
self.out.send(keypoints_message) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If going with the proposed change add parameter for thresholding?