-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PaddleOCR Parser and message #62
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import List, Union | ||
|
||
import numpy as np | ||
|
||
from .. import Classifications | ||
|
||
|
||
def create_classification_sequence_message(
    classes: List,
    scores: Union[np.ndarray, List],
    remove_duplicates: bool = False,
    ignored_indexes: List[int] = None,
) -> "Classifications":
    """Creates a message for a multi-class sequence. The 'scores' array is a sequence of
    probabilities for each class at each position in the sequence. The message contains
    the class names and their respective scores, ordered according to the sequence.

    The class chosen for each position is the argmax of its row in 'scores'. Positions
    are then filtered in two steps, both evaluated on that raw argmax sequence:

      1. If 'remove_duplicates' is True, every run of consecutive positions that predict
         the same class is collapsed to its first position.
      2. Positions whose predicted class index is listed in 'ignored_indexes' are dropped.

    Because both steps look at the unfiltered sequence, an ignored class still separates
    repeated classes. For example, with classes ["-", "a"], ignored_indexes=[0] and the
    predicted sequence ["a", "a", "-", "a"], the result is ["a", "a"].

    @param classes: A list of class names, with length 'n_classes'.
    @type classes: List
    @param scores: A numpy array (or nested list) of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes.
    @type scores: Union[np.ndarray, List]
    @param remove_duplicates: If True, collapses each run of consecutive positions that predict the same class into a single entry (the first position of the run).
    @type remove_duplicates: bool
    @param ignored_indexes: A list of class indexes to ignore during classification generation (e.g., background class, padding class). None disables the filter.
    @type ignored_indexes: List[int]

    @return: A message with attributes `classes` and `scores`, both ordered by the sequence. `scores` holds, for every kept position, the full row of 'scores' for that position.
    @rtype: Classifications

    @raises ValueError: If 'classes' is not a list.
    @raises ValueError: If 'scores' is not a 2D array or list of shape (sequence_length, n_classes).
    @raises ValueError: If the number of classes does not match the number of columns in 'scores'.
    @raises ValueError: If any score is not in the range [0, 1].
    @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1.
    @raises ValueError: If 'ignored_indexes' is neither None nor a list of valid integer indexes within the range [0, n_classes - 1].
    """

    if not isinstance(classes, List):
        raise ValueError(f"Classes should be a list, got {type(classes)}.")

    if isinstance(scores, List):
        scores = np.array(scores)

    if len(scores.shape) != 2:
        raise ValueError(f"Scores should be a 2D array, got {scores.shape}.")

    if scores.shape[1] != len(classes):
        raise ValueError(
            f"Number of labels and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores."
        )

    if np.any(scores < 0) or np.any(scores > 1):
        raise ValueError("Scores should be in the range [0, 1].")

    # Every row must be a probability distribution; reject if ANY row is off.
    if not np.all(np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)):
        raise ValueError("Each row of scores should sum to 1.")

    if ignored_indexes is not None:
        if not isinstance(ignored_indexes, List):
            raise ValueError(
                f"Ignored indexes should be a list, got {type(ignored_indexes)}."
            )
        ignored = np.asarray(ignored_indexes)
        # An empty list yields a float array, so only check dtype when non-empty.
        not_integer = ignored.size > 0 and not np.issubdtype(
            ignored.dtype, np.integer
        )
        if not_integer or np.any(ignored < 0) or np.any(ignored >= len(classes)):
            raise ValueError(
                "Ignored indexes should be integers in the range [0, num_classes -1]."
            )

    # Boolean mask over sequence positions; True means "keep this position".
    selection = np.ones(len(scores), dtype=bool)

    # Predicted class index for every position in the sequence.
    indexes = np.argmax(scores, axis=1)

    if remove_duplicates:
        # Keep a position only if its class differs from the previous position's.
        selection[1:] = indexes[1:] != indexes[:-1]

    if ignored_indexes is not None:
        # Membership test, so any number of ignored indexes is supported.
        selection &= ~np.isin(indexes, ignored_indexes)

    class_list = [classes[i] for i in indexes[selection]]
    score_list = scores[selection].tolist()

    classification_msg = Classifications()

    classification_msg.classes = class_list
    classification_msg.scores = score_list

    return classification_msg
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import List | ||
|
||
import depthai as dai | ||
import numpy as np | ||
|
||
from ..messages.creators import create_classification_sequence_message | ||
from .classification import ClassificationParser | ||
|
||
|
||
class PaddleOCRParser(ClassificationParser):
    """Parser node that turns a sequence-classification network output into a
    Classifications message.

    The node reads an NNData message with exactly one output layer, interprets it as
    per-position class scores of shape (1, sequence_length, n_classes) or
    (sequence_length, n_classes, 1), optionally applies a softmax, and builds the
    output with create_classification_sequence_message.

    Attributes
    ----------
    input : Node input created in the constructor; NNData messages are read from it.
    out : Node output created in the constructor; parsed messages are sent to it.
    remove_duplicates : bool
        If True, runs of consecutive positions predicting the same class are collapsed.
    ignored_indexes : List[int]
        Class indexes that are dropped from the decoded sequence.
    """

    def __init__(
        self,
        classes: List[str] = None,
        is_softmax: bool = True,
        remove_duplicates: bool = True,
        ignored_indexes: List[int] = None,
    ):
        """Initializes the PaddleOCR Parser node.

        @param classes: List of class names, forwarded to ClassificationParser.
        @type classes: List[str]
        @param is_softmax: Forwarded to ClassificationParser. If False, a softmax is
            applied to the scores before the message is created.
        @type is_softmax: bool
        @param remove_duplicates: If True, removes consecutive duplicates from the
            sequence.
        @type remove_duplicates: bool
        @param ignored_indexes: A list of class indexes to ignore during classification
            generation. If None, index 0 is ignored (presumably the blank/background
            class - confirm against the model).
        @type ignored_indexes: List[int]
        """
        super().__init__(classes, is_softmax)
        self.out = self.createOutput()
        self.input = self.createInput()
        self.remove_duplicates = remove_duplicates
        # None is used as the signature default to avoid a shared mutable default list.
        self.ignored_indexes = [0] if ignored_indexes is None else ignored_indexes

    def setRemoveDuplicates(self, remove_duplicates: bool):
        """Sets the remove_duplicates flag for the classification sequence model.

        @param remove_duplicates: If True, removes consecutive duplicates from the
            sequence.
        @type remove_duplicates: bool
        """
        self.remove_duplicates = remove_duplicates

    def setIgnoredIndexes(self, ignored_indexes: List[int]):
        """Sets the ignored_indexes for the classification sequence model.

        @param ignored_indexes: A list of indexes to ignore during classification
            generation.
        @type ignored_indexes: List[int]
        """
        self.ignored_indexes = ignored_indexes

    def run(self):
        """Reads NNData messages in a loop, parses each one and sends the result.

        The loop ends when the node stops running or the input queue raises
        QueueException.

        @raises ValueError: If the message does not have exactly one output layer.
        @raises ValueError: If no classes are configured (n_classes is 0).
        @raises ValueError: If the output tensor is not 3D, or has neither a leading
            nor a trailing dimension of size 1.
        """
        while self.isRunning():
            try:
                output: dai.NNData = self.input.get()

            except dai.MessageQueue.QueueException:
                break

            output_layer_names = output.getAllLayerNames()
            if len(output_layer_names) != 1:
                raise ValueError(
                    f"Expected 1 output layer, got {len(output_layer_names)}."
                )

            if self.n_classes == 0:
                raise ValueError("Classes must be provided for classification.")

            scores = output.getTensor(output_layer_names[0], dequantize=True).astype(
                np.float32
            )

            if len(scores.shape) != 3:
                raise ValueError(f"Scores should be a 3D array, got {scores.shape}.")

            # Drop the singleton dimension to get (sequence_length, n_classes).
            if scores.shape[0] == 1:
                scores = scores[0]
            elif scores.shape[2] == 1:
                scores = scores[:, :, 0]
            else:
                raise ValueError(
                    "Scores should be a 3D array of shape (1, sequence_length, n_classes) or (sequence_length, n_classes, 1)."
                )

            if not self.is_softmax:
                # Row-wise softmax. Subtracting the row maximum first leaves the
                # result unchanged but keeps np.exp from overflowing on large logits.
                shifted = scores - np.max(scores, axis=1, keepdims=True)
                exp_scores = np.exp(shifted)
                scores = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

            msg = create_classification_sequence_message(
                classes=self.classes,
                scores=scores,
                remove_duplicates=self.remove_duplicates,
                ignored_indexes=self.ignored_indexes,
            )
            msg.setTimestamp(output.getTimestamp())

            self.out.send(msg)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this needs a separate creator function. Could we adjust `create_classification_message` to also support the new features (`remove_duplicates` and `ignored_indexes`)?