From 97579859b27a9629bcbcdfdf823141d735a71121 Mon Sep 17 00:00:00 2001 From: aljazkonec1 Date: Mon, 16 Sep 2024 21:09:40 +0200 Subject: [PATCH 1/3] Refactor code. --- .../ml/messages/creators/__init__.py | 9 +- .../ml/messages/creators/classification.py | 163 +++++++++++++++++- .../creators/classification_sequence.py | 116 ------------- depthai_nodes/ml/messages/creators/misc.py | 65 +------ depthai_nodes/ml/parsers/__init__.py | 3 +- depthai_nodes/ml/parsers/classification.py | 61 ++++++- depthai_nodes/ml/parsers/ppdet.py | 2 +- .../ml/parsers/vehicle_attributes.py | 61 ------- .../test_classification_sequence.py | 2 +- tests/unittests/test_creators/test_misc.py | 15 +- .../test_creators/test_vehicle_attributes.py | 2 +- 11 files changed, 247 insertions(+), 252 deletions(-) delete mode 100644 depthai_nodes/ml/messages/creators/classification_sequence.py delete mode 100644 depthai_nodes/ml/parsers/vehicle_attributes.py diff --git a/depthai_nodes/ml/messages/creators/__init__.py b/depthai_nodes/ml/messages/creators/__init__.py index 8656ffa2..17cdaf0c 100644 --- a/depthai_nodes/ml/messages/creators/__init__.py +++ b/depthai_nodes/ml/messages/creators/__init__.py @@ -1,11 +1,14 @@ -from .classification import create_classification_message -from .classification_sequence import create_classification_sequence_message +from .classification import ( + create_classification_message, + create_classification_sequence_message, + create_multi_classification_message, +) from .clusters import create_cluster_message from .detection import create_detection_message, create_line_detection_message from .image import create_image_message from .keypoints import create_hand_keypoints_message, create_keypoints_message from .map import create_map_message -from .misc import create_age_gender_message, create_multi_classification_message +from .misc import create_age_gender_message from .segmentation import create_sam_message, create_segmentation_message from .tracked_features import create_tracked_features_message diff --git a/depthai_nodes/ml/messages/creators/classification.py b/depthai_nodes/ml/messages/creators/classification.py index 42ed2633..49b7dd64 100644 --- a/depthai_nodes/ml/messages/creators/classification.py +++ b/depthai_nodes/ml/messages/creators/classification.py @@ -2,7 +2,7 @@ import numpy as np -from ...messages import Classifications +from ...messages import Classifications, CompositeMessage def create_classification_message( @@ -82,3 +82,164 @@ def create_classification_message( classification_msg.scores = scores.tolist() return classification_msg + + +def create_multi_classification_message( + classification_attributes: List[str], + classification_scores: Union[np.ndarray, List[List[float]]], + classification_labels: List[List[str]], +) -> CompositeMessage: + """Create a DepthAI message for multi-classification. + + @param classification_attributes: List of attributes being classified. + @type classification_attributes: List[str] + @param classification_scores: A 2D array or list of classification scores for each + attribute. + @type classification_scores: Union[np.ndarray, List[List[float]]] + @param classification_labels: A 2D list of class labels for each classification + attribute. + @type classification_labels: List[List[str]] + @return: MultiClassification message containing a dictionary of classification + attributes and their respective Classifications. + @rtype: dai.Buffer + @raise ValueError: If number of attributes is not same as number of score-label + pairs. + @raise ValueError: If number of scores is not same as number of labels for each + attribute. + @raise ValueError: If each class score not in the range [0, 1]. + @raise ValueError: If each class score not a probability distribution that sums to + 1. + """ + + if len(classification_attributes) != len(classification_scores) or len( + classification_attributes + ) != len(classification_labels): + raise ValueError( + f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels." + ) + + multi_class_dict = {} + for attribute, scores, labels in zip( + classification_attributes, classification_scores, classification_labels + ): + if len(scores) != len(labels): + raise ValueError( + f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}." + ) + multi_class_dict[attribute] = create_classification_message(labels, scores) + + multi_classification_message = CompositeMessage() + multi_classification_message.setData(multi_class_dict) + + return multi_classification_message + + +def create_classification_sequence_message( + classes: List[str], + scores: Union[np.ndarray, List], + ignored_indexes: List[int] = None, + remove_duplicates: bool = False, + concatenate_text: bool = False, +) -> Classifications: + """Creates a message for a multi-class sequence. The 'scores' array is a sequence of + probabilities for each class at each position in the sequence. The message contains + the class names and their respective scores, ordered according to the sequence. + + @param classes: A list of class names, with length 'n_classes'. + @type classes: List + @param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes. + @type scores: np.ndarray + @param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class) + @type ignored_indexes: List[int] + @param remove_duplicates: If True, removes consecutive duplicates from the sequence. + @type remove_duplicates: bool + @param concatenate_text: If True, concatenates consecutive words based on the space character. + @type concatenate_text: bool + @return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores. + @rtype: Classifications + @raises ValueError: If 'classes' is not a list of strings. + @raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes). + @raises ValueError: If the number of classes does not match the number of columns in 'scores'. + @raises ValueError: If any score is not in the range [0, 1]. + @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1. + @raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1]. + """ + + if not isinstance(classes, List): + raise ValueError(f"Classes should be a list, got {type(classes)}.") + + if isinstance(scores, List): + scores = np.array(scores) + + if len(scores.shape) != 2: + raise ValueError(f"Scores should be a 2D array, got {scores.shape}.") + + if scores.shape[1] != len(classes): + raise ValueError( + f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores." + ) + + if np.any(scores < 0) or np.any(scores > 1): + raise ValueError("Scores should be in the range [0, 1].") + + if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)): + raise ValueError("Each row of scores should sum to 1.") + + if ignored_indexes is not None: + if not isinstance(ignored_indexes, List): + raise ValueError( + f"Ignored indexes should be a list, got {type(ignored_indexes)}." + ) + if not all(isinstance(index, int) for index in ignored_indexes): + raise ValueError("Ignored indexes should be integers.") + if np.any(np.array(ignored_indexes) < 0) or np.any( + np.array(ignored_indexes) >= len(classes) + ): + raise ValueError( + "Ignored indexes should be integers in the range [0, num_classes -1]." + ) + + selection = np.ones(len(scores), dtype=bool) + indexes = np.argmax(scores, axis=1) + + if remove_duplicates: + selection[1:] = indexes[1:] != indexes[:-1] + + if ignored_indexes is not None: + selection &= np.array([index not in ignored_indexes for index in indexes]) + + class_list = [classes[i] for i in indexes[selection]] + score_list = np.max(scores, axis=1)[selection] + + if ( + concatenate_text + and len(class_list) > 1 + and all(len(word) <= 1 for word in class_list) + ): + concatenated_scores = [] + concatenated_words = "".join(class_list).split() + cumsumlist = np.cumsum([len(word) for word in concatenated_words]) + + start_index = 0 + for num_spaces, end_index in enumerate(cumsumlist): + word_scores = score_list[start_index + num_spaces : end_index + num_spaces] + concatenated_scores.append(np.mean(word_scores)) + start_index = end_index + + class_list = concatenated_words + score_list = np.array(concatenated_scores) + + elif ( + concatenate_text + and len(class_list) > 1 + and any(len(word) >= 2 for word in class_list) + ): + class_list = [" ".join(class_list)] + score_list = np.mean(score_list) + + classification_msg = Classifications() + + classification_msg.classes = class_list + classification_msg.scores = score_list.tolist() + + return classification_msg diff --git a/depthai_nodes/ml/messages/creators/classification_sequence.py b/depthai_nodes/ml/messages/creators/classification_sequence.py deleted file mode 100644 index 216905f7..00000000 --- a/depthai_nodes/ml/messages/creators/classification_sequence.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import List, Union - -import numpy as np - -from .. import Classifications - - -def create_classification_sequence_message( - classes: List[str], - scores: Union[np.ndarray, List], - ignored_indexes: List[int] = None, - remove_duplicates: bool = False, - concatenate_text: bool = False, -) -> Classifications: - """Creates a message for a multi-class sequence. The 'scores' array is a sequence of - probabilities for each class at each position in the sequence. The message contains - the class names and their respective scores, ordered according to the sequence. - - @param classes: A list of class names, with length 'n_classes'. - @type classes: List - @param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes. - @type scores: np.ndarray - @param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class) - @type ignored_indexes: List[int] - @param remove_duplicates: If True, removes consecutive duplicates from the sequence. - @type remove_duplicates: bool - @param concatenate_text: If True, concatenates consecutive words based on the space character. - @type concatenate_text: bool - @return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores. - @rtype: Classifications - @raises ValueError: If 'classes' is not a list of strings. - @raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes). - @raises ValueError: If the number of classes does not match the number of columns in 'scores'. - @raises ValueError: If any score is not in the range [0, 1]. - @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1. - @raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1]. - """ - - if not isinstance(classes, List): - raise ValueError(f"Classes should be a list, got {type(classes)}.") - - if isinstance(scores, List): - scores = np.array(scores) - - if len(scores.shape) != 2: - raise ValueError(f"Scores should be a 2D array, got {scores.shape}.") - - if scores.shape[1] != len(classes): - raise ValueError( - f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores." - ) - - if np.any(scores < 0) or np.any(scores > 1): - raise ValueError("Scores should be in the range [0, 1].") - - if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)): - raise ValueError("Each row of scores should sum to 1.") - - if ignored_indexes is not None: - if not isinstance(ignored_indexes, List): - raise ValueError( - f"Ignored indexes should be a list, got {type(ignored_indexes)}." - ) - if not all(isinstance(index, int) for index in ignored_indexes): - raise ValueError("Ignored indexes should be integers.") - if np.any(np.array(ignored_indexes) < 0) or np.any( - np.array(ignored_indexes) >= len(classes) - ): - raise ValueError( - "Ignored indexes should be integers in the range [0, num_classes -1]." - ) - - selection = np.ones(len(scores), dtype=bool) - indexes = np.argmax(scores, axis=1) - - if remove_duplicates: - selection[1:] = indexes[1:] != indexes[:-1] - - if ignored_indexes is not None: - selection &= np.array([index not in ignored_indexes for index in indexes]) - - class_list = [classes[i] for i in indexes[selection]] - score_list = np.max(scores, axis=1)[selection] - - if ( - concatenate_text - and len(class_list) > 1 - and all(len(word) <= 1 for word in class_list) - ): - concatenated_scores = [] - concatenated_words = "".join(class_list).split() - cumsumlist = np.cumsum([len(word) for word in concatenated_words]) - - start_index = 0 - for num_spaces, end_index in enumerate(cumsumlist): - word_scores = score_list[start_index + num_spaces : end_index + num_spaces] - concatenated_scores.append(np.mean(word_scores)) - start_index = end_index - - class_list = concatenated_words - score_list = np.array(concatenated_scores) - - elif ( - concatenate_text - and len(class_list) > 1 - and any(len(word) >= 2 for word in class_list) - ): - class_list = [" ".join(class_list)] - score_list = np.mean(score_list) - - classification_msg = Classifications() - - classification_msg.classes = class_list - classification_msg.scores = score_list.tolist() - - return classification_msg diff --git a/depthai_nodes/ml/messages/creators/misc.py b/depthai_nodes/ml/messages/creators/misc.py index b2999eb1..7ef091d5 100644 --- a/depthai_nodes/ml/messages/creators/misc.py +++ b/depthai_nodes/ml/messages/creators/misc.py @@ -1,12 +1,9 @@ -from typing import List, Union +from typing import List -import numpy as np +from ...messages import Classifications, CompositeMessage -from ...messages import AgeGender, Classifications, CompositeMessage -from ...messages.creators import create_classification_message - -def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender: +def create_age_gender_message(age: float, gender_prob: List[float]) -> CompositeMessage: """Create a DepthAI message for the age and gender probability. @param age: Detected person age (must be multiplied by 100 to get years). @@ -42,61 +39,11 @@ def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender f"Gender_prob list must contain probabilities and sum to 1, got sum {sum(gender_prob)}." ) - age_gender_message = AgeGender() - age_gender_message.age = age gender = Classifications() gender.classes = ["female", "male"] gender.scores = gender_prob - age_gender_message.gender = gender - - return age_gender_message - - -def create_multi_classification_message( - classification_attributes: List[str], - classification_scores: Union[np.ndarray, List[List[float]]], - classification_labels: List[List[str]], -) -> CompositeMessage: - """Create a DepthAI message for multi-classification. - - @param classification_attributes: List of attributes being classified. - @type classification_attributes: List[str] - @param classification_scores: A 2D array or list of classification scores for each - attribute. - @type classification_scores: Union[np.ndarray, List[List[float]]] - @param classification_labels: A 2D list of class labels for each classification - attribute. - @type classification_labels: List[List[str]] - @return: MultiClassification message containing a dictionary of classification - attributes and their respective Classifications. - @rtype: dai.Buffer - @raise ValueError: If number of attributes is not same as number of score-label - pairs. - @raise ValueError: If number of scores is not same as number of labels for each - attribute. - @raise ValueError: If each class score not in the range [0, 1]. - @raise ValueError: If each class score not a probability distribution that sums to - 1. - """ - if len(classification_attributes) != len(classification_scores) or len( - classification_attributes - ) != len(classification_labels): - raise ValueError( - f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels." - ) + age_gender_message = CompositeMessage() + age_gender_message.setData({"age": age, "gender": gender}) - multi_class_dict = {} - for attribute, scores, labels in zip( - classification_attributes, classification_scores, classification_labels - ): - if len(scores) != len(labels): - raise ValueError( - f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}." - ) - multi_class_dict[attribute] = create_classification_message(labels, scores) - - multi_classification_message = CompositeMessage() - multi_classification_message.setData(multi_class_dict) - - return multi_classification_message + return age_gender_message diff --git a/depthai_nodes/ml/parsers/__init__.py b/depthai_nodes/ml/parsers/__init__.py index d72b6c94..61982eca 100644 --- a/depthai_nodes/ml/parsers/__init__.py +++ b/depthai_nodes/ml/parsers/__init__.py @@ -1,5 +1,5 @@ from .age_gender import AgeGenderParser -from .classification import ClassificationParser +from .classification import ClassificationParser, MultiClassificationParser from .fastsam import FastSAMParser from .hrnet import HRNetParser from .image_output import ImageOutputParser @@ -14,7 +14,6 @@ from .scrfd import SCRFDParser from .segmentation import SegmentationParser from .superanimal_landmarker import SuperAnimalParser -from .vehicle_attributes import MultiClassificationParser from .xfeat import XFeatParser from .yolo import YOLOExtendedParser from .yunet import YuNetParser diff --git a/depthai_nodes/ml/parsers/classification.py b/depthai_nodes/ml/parsers/classification.py index 6295e7c6..77716739 100644 --- a/depthai_nodes/ml/parsers/classification.py +++ b/depthai_nodes/ml/parsers/classification.py @@ -3,7 +3,10 @@ import depthai as dai import numpy as np -from ..messages.creators import create_classification_message +from ..messages.creators import ( + create_classification_message, + create_multi_classification_message, +) class ClassificationParser(dai.node.ThreadedHostNode): @@ -95,3 +98,59 @@ def run(self): msg.setTimestamp(output.getTimestamp()) self.out.send(msg) + + +class MultiClassificationParser(dai.node.ThreadedHostNode): + """Postprocessing logic for Multiple Classification model. + + Attributes + ---------- + input : Node.Input + Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. + out : Node.Output + Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. + classification_attributes : List[str] + List of attributes to be classified. + classification_labels : List[List[str]] + List of class labels for each attribute in `classification_attributes` + + Output Message/s + ---------------- + **Type**: CompositeMessage + + **Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values. + """ + + def __init__( + self, + classification_attributes: List[str], + classification_labels: List[List[str]], + ): + """Initializes the MultipleClassificationParser node.""" + dai.node.ThreadedHostNode.__init__(self) + self.out = self.createOutput() + self.input = self.createInput() + self.classification_attributes: List[str] = classification_attributes + self.classification_labels: List[List[str]] = classification_labels + + def run(self): + while self.isRunning(): + try: + output: dai.NNData = self.input.get() + except dai.MessageQueue.QueueException: + break + + layer_names = output.getAllLayerNames() + + scores = [] + for layer_name in layer_names: + scores.append( + output.getTensor(layer_name, dequantize=True).flatten().tolist() + ) + + multi_classification_message = create_multi_classification_message( + self.classification_attributes, scores, self.classification_labels + ) + multi_classification_message.setTimestamp(output.getTimestamp()) + + self.out.send(multi_classification_message) diff --git a/depthai_nodes/ml/parsers/ppdet.py b/depthai_nodes/ml/parsers/ppdet.py index 8a74fcbe..23d81cc5 100644 --- a/depthai_nodes/ml/parsers/ppdet.py +++ b/depthai_nodes/ml/parsers/ppdet.py @@ -1,7 +1,7 @@ import depthai as dai from ..messages.creators import create_detection_message -from .utils.ppdet import corners2xyxy, parse_paddle_detection_outputs +from .utils import corners2xyxy, parse_paddle_detection_outputs class PPTextDetectionParser(dai.node.ThreadedHostNode): diff --git a/depthai_nodes/ml/parsers/vehicle_attributes.py b/depthai_nodes/ml/parsers/vehicle_attributes.py deleted file mode 100644 index 5e5d1c17..00000000 --- a/depthai_nodes/ml/parsers/vehicle_attributes.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import List - -import depthai as dai - -from ..messages.creators import create_multi_classification_message - - -class MultiClassificationParser(dai.node.ThreadedHostNode): - """Postprocessing logic for Multiple Classification model. - - Attributes - ---------- - input : Node.Input - Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. - out : Node.Output - Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. - classification_attributes : List[str] - List of attributes to be classified. - classification_labels : List[List[str]] - List of class labels for each attribute in `classification_attributes` - - Output Message/s - ---------------- - **Type**: CompositeMessage - - **Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values. - """ - - def __init__( - self, - classification_attributes: List[str], - classification_labels: List[List[str]], - ): - """Initializes the MultipleClassificationParser node.""" - dai.node.ThreadedHostNode.__init__(self) - self.out = self.createOutput() - self.input = self.createInput() - self.classification_attributes: List[str] = classification_attributes - self.classification_labels: List[List[str]] = classification_labels - - def run(self): - while self.isRunning(): - try: - output: dai.NNData = self.input.get() - except dai.MessageQueue.QueueException: - break - - layer_names = output.getAllLayerNames() - - scores = [] - for layer_name in layer_names: - scores.append( - output.getTensor(layer_name, dequantize=True).flatten().tolist() - ) - - multi_classification_message = create_multi_classification_message( - self.classification_attributes, scores, self.classification_labels - ) - multi_classification_message.setTimestamp(output.getTimestamp()) - - self.out.send(multi_classification_message) diff --git a/tests/unittests/test_creators/test_classification_sequence.py b/tests/unittests/test_creators/test_classification_sequence.py index f4017e1b..a1f85e5a 100644 --- a/tests/unittests/test_creators/test_classification_sequence.py +++ b/tests/unittests/test_creators/test_classification_sequence.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from depthai_nodes.ml.messages.creators.classification_sequence import ( +from depthai_nodes.ml.messages.creators import ( create_classification_sequence_message, ) diff --git a/tests/unittests/test_creators/test_misc.py b/tests/unittests/test_creators/test_misc.py index dced1045..e49c1539 100644 --- a/tests/unittests/test_creators/test_misc.py +++ b/tests/unittests/test_creators/test_misc.py @@ -1,8 +1,8 @@ import numpy as np import pytest -from depthai_nodes.ml.messages import AgeGender -from depthai_nodes.ml.messages.creators.misc import create_age_gender_message +from depthai_nodes.ml.messages import CompositeMessage +from depthai_nodes.ml.messages.creators import create_age_gender_message def test_wrong_age(): @@ -45,10 +45,13 @@ def test_correct_types(): gender = [0.35, 0.65] message = create_age_gender_message(age, gender) - assert isinstance(message, AgeGender) - assert message.age == age - assert message.gender.classes == ["female", "male"] - assert np.all(np.isclose(message.gender.scores, gender)) + assert isinstance(message, CompositeMessage) + result = message.getData() + assert "age" in result + assert "gender" in result + assert result["age"] == age + assert result["gender"].classes == ["female", "male"] + assert np.all(np.isclose(result["gender"].scores, gender)) if __name__ == "__main__": diff --git a/tests/unittests/test_creators/test_vehicle_attributes.py b/tests/unittests/test_creators/test_vehicle_attributes.py index 07363824..a341bfc7 100644 --- a/tests/unittests/test_creators/test_vehicle_attributes.py +++ b/tests/unittests/test_creators/test_vehicle_attributes.py @@ -1,7 +1,7 @@ import pytest from depthai_nodes.ml.messages import Classifications, CompositeMessage -from depthai_nodes.ml.messages.creators.misc import create_multi_classification_message +from depthai_nodes.ml.messages.creators import create_multi_classification_message def test_incorect_lengths(): From c98abfd24b74d3dc3c4443cedf28e2b79c9110f7 Mon Sep 17 00:00:00 2001 From: aljazkonec1 Date: Tue, 17 Sep 2024 08:34:41 +0200 Subject: [PATCH 2/3] remove AgeGender Message. --- depthai_nodes/ml/messages/__init__.py | 2 -- depthai_nodes/ml/messages/misc.py | 34 --------------------------- 2 files changed, 36 deletions(-) delete mode 100644 depthai_nodes/ml/messages/misc.py diff --git a/depthai_nodes/ml/messages/__init__.py b/depthai_nodes/ml/messages/__init__.py index 25ae687a..835a1c42 100644 --- a/depthai_nodes/ml/messages/__init__.py +++ b/depthai_nodes/ml/messages/__init__.py @@ -8,7 +8,6 @@ from .keypoints import HandKeypoints, Keypoints from .lines import Line, Lines from .map import Map2D -from .misc import AgeGender from .segmentation import SegmentationMasks __all__ = [ @@ -20,7 +19,6 @@ "Lines", "Classifications", "SegmentationMasks", - "AgeGender", "Map2D", "Clusters", "Cluster", diff --git a/depthai_nodes/ml/messages/misc.py b/depthai_nodes/ml/messages/misc.py deleted file mode 100644 index c3aaf838..00000000 --- a/depthai_nodes/ml/messages/misc.py +++ /dev/null @@ -1,34 +0,0 @@ -import depthai as dai - -from ..messages import Classifications - - -class AgeGender(dai.Buffer): - def __init__(self): - super().__init__() - self._age: float = None - self._gender = Classifications() - - @property - def age(self) -> float: - return self._age - - @age.setter - def age(self, value: float): - if not isinstance(value, float): - raise TypeError( - f"start_point must be of type float, instead got {type(value)}." - ) - self._age = value - - @property - def gender(self) -> Classifications: - return self._gender - - @gender.setter - def gender(self, value: Classifications): - if not isinstance(value, Classifications): - raise TypeError( - f"gender must be of type Classifications, instead got {type(value)}." - ) - self._gender = value From 13d7d4a62c25256ad7d4397ab7ccef7a5492927c Mon Sep 17 00:00:00 2001 From: aljazkonec1 Date: Tue, 17 Sep 2024 09:14:10 +0200 Subject: [PATCH 3/3] Refactor examples. --- examples/visualization/classification.py | 6 ++++-- examples/visualization/messages.py | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/visualization/classification.py b/examples/visualization/classification.py index 90c3bb04..80e70143 100644 --- a/examples/visualization/classification.py +++ b/examples/visualization/classification.py @@ -1,7 +1,7 @@ import cv2 import depthai as dai -from depthai_nodes.ml.messages import AgeGender, Classifications +from depthai_nodes.ml.messages import Classifications, CompositeMessage from .messages import parse_classification_message, parser_age_gender_message @@ -33,7 +33,9 @@ def visualize_classification( return False -def visualize_age_gender(frame: dai.ImgFrame, message: AgeGender, extraParams: dict): +def visualize_age_gender( + frame: dai.ImgFrame, message: CompositeMessage, extraParams: dict +): """Visualizes the age and predicted gender on the frame.""" if frame.shape[0] < 128: frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2)) diff --git a/examples/visualization/messages.py b/examples/visualization/messages.py index facd212f..37a5e882 100644 --- a/examples/visualization/messages.py +++ b/examples/visualization/messages.py @@ -1,9 +1,9 @@ import depthai as dai from depthai_nodes.ml.messages import ( - AgeGender, Classifications, Clusters, + CompositeMessage, ImgDetectionsExtended, Keypoints, Lines, @@ -49,11 +49,11 @@ def parse_image_message(message: dai.ImgFrame): return image -def parser_age_gender_message(message: AgeGender): +def parser_age_gender_message(message: CompositeMessage): """Parses the age-gender message and return the age and scores for all genders.""" - - age = message.age - gender = message.gender + message = message.getData() + age = message["age"] + gender = message["gender"] gender_scores = gender.scores gender_classes = gender.classes