Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dai.ImgTransformations to messages and improvements to Text detection parser #142

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ class Classifications(dai.Buffer):
A list of classes.
scores : NDArray[np.float32]
Corresponding probability scores.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
"""Initializes the Classifications object."""
dai.Buffer.__init__(self)
self._classes: List[str] = []
self._scores: NDArray[np.float32] = np.array([])
self._transformation: dai.ImgTransformation = None

@property
def classes(self) -> List:
Expand Down Expand Up @@ -91,3 +94,26 @@ def top_score(self) -> float:
@rtype: float
"""
return self._scores[0]

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the transformation property and setter are the same for every message, we could also make a message parent class that would define them once instead of copy-pasting them to all messages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I also thought about it. I asked depthai team and when this will get ported, all messages will inherit from dai.Buffer and will have similar structure as dai.ImgDetections currently has. Thats why I left it as is.


@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ class Clusters(dai.Buffer):
----------
clusters : List[Cluster]
List of clusters.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
"""Initializes the Clusters object."""
super().__init__()
self._clusters: List[Cluster] = []
self._transformation: dai.ImgTransformation = None

@property
def clusters(self) -> List[Cluster]:
Expand All @@ -103,3 +106,26 @@ def clusters(self, value: List[Cluster]):
if not all(isinstance(cluster, Cluster) for cluster in value):
raise ValueError("Clusters must be a list of Cluster objects.")
self._clusters = value

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/img_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,16 @@ class ImgDetectionsExtended(dai.Buffer):
Image detections with keypoints.
masks: np.ndarray
The segmentation masks of the image. All masks are stored in a single numpy array.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self) -> None:
"""Initializes the ImgDetectionsExtended object."""
super().__init__()
self._detections: List[ImgDetectionExtended] = []
self._masks: SegmentationMask = SegmentationMask()
self._transformation: dai.ImgTransformation = None

@property
def detections(self) -> List[ImgDetectionExtended]:
Expand Down Expand Up @@ -226,3 +229,26 @@ def masks(self, value: NDArray[np.int16]):
masks_msg = SegmentationMask()
masks_msg.mask = value
self._masks = masks_msg

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,15 @@ class Keypoints(dai.Buffer):
----------
keypoints: List[Keypoint]
List of Keypoint objects, each representing a keypoint.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
"""Initializes the Keypoints object."""
super().__init__()
self._keypoints: List[Keypoint] = []
self._transformation: dai.ImgTransformation = None

@property
def keypoints(self) -> List[Keypoint]:
Expand All @@ -157,3 +160,26 @@ def keypoints(self, value: List[Keypoint]):
if not all(isinstance(item, Keypoint) for item in value):
raise ValueError("keypoints must be a list of Keypoint objects.")
self._keypoints = value

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,15 @@ class Lines(dai.Buffer):
----------
lines : List[Line]
List of detected lines.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
"""Initializes the Lines object."""
super().__init__()
self._lines: List[Line] = []
self._transformation: dai.ImgTransformation = None

@property
def lines(self) -> List[Line]:
Expand All @@ -130,3 +133,26 @@ def lines(self, value: List[Line]):
if not all(isinstance(item, Line) for item in value):
raise ValueError("Lines must be a list of Line objects.")
self._lines = value

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class Map2D(dai.Buffer):
2D Map width.
height : int
2D Map height.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
Expand All @@ -22,6 +24,7 @@ def __init__(self):
self._map: NDArray[np.float32] = np.array([])
self._width: int = None
self._height: int = None
self._transformation: dai.ImgTransformation = None

@property
def map(self) -> NDArray[np.float32]:
Expand Down Expand Up @@ -71,3 +74,26 @@ def height(self) -> int:
@rtype: int
"""
return self._height

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ class Predictions(dai.Buffer):
----------
predictions : List[Prediction]
List of predictions.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
"""Initializes the Predictions object."""
super().__init__()
self._predictions: List[Prediction] = []
self._transformation: dai.ImgTransformation = None

@property
def predictions(self) -> List[Prediction]:
Expand Down Expand Up @@ -89,3 +92,26 @@ def prediction(self) -> float:
@rtype: float
"""
return self._predictions[0].prediction

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ class SegmentationMask(dai.Buffer):
----------
mask: NDArray[np.int16]
Segmentation mask.
transformation : dai.ImgTransformation
Image transformation object.
"""

def __init__(self):
"""Initializes the SegmentationMask object."""
super().__init__()
self._mask: NDArray[np.int16] = np.array([])
self._transformation: dai.ImgTransformation = None

@property
def mask(self) -> NDArray[np.int16]:
Expand Down Expand Up @@ -48,3 +51,26 @@ def mask(self, value: NDArray[np.int16]):
if np.any((value < -1)):
raise ValueError("Mask must be an array of integers larger or equal to -1.")
self._mask = value

@property
def transformation(self) -> dai.ImgTransformation:
"""Returns the Image Transformation object.

@return: The Image Transformation object.
@rtype: dai.ImgTransformation
"""
return self._transformation

@transformation.setter
def transformation(self, value: dai.ImgTransformation):
"""Sets the Image Transformation object.

@param value: The Image Transformation object.
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value
1 change: 1 addition & 0 deletions depthai_nodes/ml/parsers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def run(self):
scores = softmax(scores)

msg = create_classification_message(self.classes, scores)
msg.transformation = output.getTransformation()
msg.setTimestamp(output.getTimestamp())

self.out.send(msg)
1 change: 1 addition & 0 deletions depthai_nodes/ml/parsers/classification_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def run(self):
ignored_indexes=self.ignored_indexes,
concatenate_classes=self.concatenate_classes,
)
msg.transformation = output.getTransformation()
msg.setTimestamp(output.getTimestamp())

self.out.send(msg)
2 changes: 1 addition & 1 deletion depthai_nodes/ml/parsers/fastsam.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,6 @@ def run(self):
results_masks = merge_masks(results_masks)

segmentation_message = create_segmentation_message(results_masks)
segmentation_message.transformation = output.getTransformation()
segmentation_message.setTimestamp(output.getTimestamp())

self.out.send(segmentation_message)
1 change: 1 addition & 0 deletions depthai_nodes/ml/parsers/hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,6 @@ def run(self):
confidence_threshold=self.score_threshold,
)
keypoints_message.setTimestamp(output.getTimestamp())
keypoints_message.transformation = output.getTransformation()

self.out.send(keypoints_message)
1 change: 1 addition & 0 deletions depthai_nodes/ml/parsers/image_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,6 @@ def run(self):
is_bgr=self.output_is_bgr,
)
image_message.setTimestamp(output.getTimestamp())
image_message.transformation = output.getTransformation()

self.out.send(image_message)
1 change: 1 addition & 0 deletions depthai_nodes/ml/parsers/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,6 @@ def run(self):

msg = create_keypoints_message(keypoints)
msg.setTimestamp(output.getTimestamp())
msg.transformation = output.getTransformation()

self.out.send(msg)
Loading
Loading