Skip to content

Commit

Permalink
Merge pull request #68 from luxonis/update-parser-locations
Browse files Browse the repository at this point in the history
Refactor Code
  • Loading branch information
aljazkonec1 authored Sep 17, 2024
2 parents ceac1c8 + 13d7d4a commit a6d34c1
Show file tree
Hide file tree
Showing 15 changed files with 256 additions and 295 deletions.
2 changes: 0 additions & 2 deletions depthai_nodes/ml/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .keypoints import HandKeypoints, Keypoints
from .lines import Line, Lines
from .map import Map2D
from .misc import AgeGender
from .segmentation import SegmentationMasks

__all__ = [
Expand All @@ -20,7 +19,6 @@
"Lines",
"Classifications",
"SegmentationMasks",
"AgeGender",
"Map2D",
"Clusters",
"Cluster",
Expand Down
9 changes: 6 additions & 3 deletions depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from .classification import create_classification_message
from .classification_sequence import create_classification_sequence_message
from .classification import (
create_classification_message,
create_classification_sequence_message,
create_multi_classification_message,
)
from .clusters import create_cluster_message
from .detection import create_detection_message, create_line_detection_message
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .map import create_map_message
from .misc import create_age_gender_message, create_multi_classification_message
from .misc import create_age_gender_message
from .segmentation import create_sam_message, create_segmentation_message
from .tracked_features import create_tracked_features_message

Expand Down
163 changes: 162 additions & 1 deletion depthai_nodes/ml/messages/creators/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ...messages import Classifications
from ...messages import Classifications, CompositeMessage


def create_classification_message(
Expand Down Expand Up @@ -82,3 +82,164 @@ def create_classification_message(
classification_msg.scores = scores.tolist()

return classification_msg


def create_multi_classification_message(
classification_attributes: List[str],
classification_scores: Union[np.ndarray, List[List[float]]],
classification_labels: List[List[str]],
) -> CompositeMessage:
"""Create a DepthAI message for multi-classification.
@param classification_attributes: List of attributes being classified.
@type classification_attributes: List[str]
@param classification_scores: A 2D array or list of classification scores for each
attribute.
@type classification_scores: Union[np.ndarray, List[List[float]]]
@param classification_labels: A 2D list of class labels for each classification
attribute.
@type classification_labels: List[List[str]]
@return: MultiClassification message containing a dictionary of classification
attributes and their respective Classifications.
@rtype: dai.Buffer
@raise ValueError: If number of attributes is not same as number of score-label
pairs.
@raise ValueError: If number of scores is not same as number of labels for each
attribute.
@raise ValueError: If each class score not in the range [0, 1].
@raise ValueError: If each class score not a probability distribution that sums to
1.
"""

if len(classification_attributes) != len(classification_scores) or len(
classification_attributes
) != len(classification_labels):
raise ValueError(
f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels."
)

multi_class_dict = {}
for attribute, scores, labels in zip(
classification_attributes, classification_scores, classification_labels
):
if len(scores) != len(labels):
raise ValueError(
f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}."
)
multi_class_dict[attribute] = create_classification_message(labels, scores)

multi_classification_message = CompositeMessage()
multi_classification_message.setData(multi_class_dict)

return multi_classification_message


def create_classification_sequence_message(
classes: List[str],
scores: Union[np.ndarray, List],
ignored_indexes: List[int] = None,
remove_duplicates: bool = False,
concatenate_text: bool = False,
) -> Classifications:
"""Creates a message for a multi-class sequence. The 'scores' array is a sequence of
probabilities for each class at each position in the sequence. The message contains
the class names and their respective scores, ordered according to the sequence.
@param classes: A list of class names, with length 'n_classes'.
@type classes: List
@param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes.
@type scores: np.ndarray
@param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class)
@type ignored_indexes: List[int]
@param remove_duplicates: If True, removes consecutive duplicates from the sequence.
@type remove_duplicates: bool
@param concatenate_text: If True, concatenates consecutive words based on the space character.
@type concatenate_text: bool
@return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores.
@rtype: Classifications
@raises ValueError: If 'classes' is not a list of strings.
@raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes).
@raises ValueError: If the number of classes does not match the number of columns in 'scores'.
@raises ValueError: If any score is not in the range [0, 1].
@raises ValueError: If the probabilities in any row of 'scores' do not sum to 1.
@raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1].
"""

if not isinstance(classes, List):
raise ValueError(f"Classes should be a list, got {type(classes)}.")

if isinstance(scores, List):
scores = np.array(scores)

if len(scores.shape) != 2:
raise ValueError(f"Scores should be a 2D array, got {scores.shape}.")

if scores.shape[1] != len(classes):
raise ValueError(
f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores."
)

if np.any(scores < 0) or np.any(scores > 1):
raise ValueError("Scores should be in the range [0, 1].")

if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)):
raise ValueError("Each row of scores should sum to 1.")

if ignored_indexes is not None:
if not isinstance(ignored_indexes, List):
raise ValueError(
f"Ignored indexes should be a list, got {type(ignored_indexes)}."
)
if not all(isinstance(index, int) for index in ignored_indexes):
raise ValueError("Ignored indexes should be integers.")
if np.any(np.array(ignored_indexes) < 0) or np.any(
np.array(ignored_indexes) >= len(classes)
):
raise ValueError(
"Ignored indexes should be integers in the range [0, num_classes -1]."
)

selection = np.ones(len(scores), dtype=bool)
indexes = np.argmax(scores, axis=1)

if remove_duplicates:
selection[1:] = indexes[1:] != indexes[:-1]

if ignored_indexes is not None:
selection &= np.array([index not in ignored_indexes for index in indexes])

class_list = [classes[i] for i in indexes[selection]]
score_list = np.max(scores, axis=1)[selection]

if (
concatenate_text
and len(class_list) > 1
and all(len(word) <= 1 for word in class_list)
):
concatenated_scores = []
concatenated_words = "".join(class_list).split()
cumsumlist = np.cumsum([len(word) for word in concatenated_words])

start_index = 0
for num_spaces, end_index in enumerate(cumsumlist):
word_scores = score_list[start_index + num_spaces : end_index + num_spaces]
concatenated_scores.append(np.mean(word_scores))
start_index = end_index

class_list = concatenated_words
score_list = np.array(concatenated_scores)

elif (
concatenate_text
and len(class_list) > 1
and any(len(word) >= 2 for word in class_list)
):
class_list = [" ".join(class_list)]
score_list = np.mean(score_list)

classification_msg = Classifications()

classification_msg.classes = class_list
classification_msg.scores = score_list.tolist()

return classification_msg
116 changes: 0 additions & 116 deletions depthai_nodes/ml/messages/creators/classification_sequence.py

This file was deleted.

Loading

0 comments on commit a6d34c1

Please sign in to comment.