Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement/img detection extended #130

Merged
merged 11 commits into from
Nov 14, 2024
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ jobs:
cache: pip

- name: Install depthai
run: pip install --extra-index-url https://artifacts.luxonis.com/artifactory/luxonis-python-release-local/ depthai==3.0.0a6
run: pip install --extra-index-url https://artifacts.luxonis.com/artifactory/luxonis-python-release-local/ depthai==3.0.0a6


- name: Install package
run: pip install -e .[dev]
Expand Down
49 changes: 30 additions & 19 deletions depthai_nodes/ml/messages/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,35 @@ Here are the custom message types that we introduce in this package. They are us

**Table of Contents**

- [Classifications](#classifications)
- [Cluster](#cluster)
- [Clusters](#clusters)
- [ImgDetectionExtended](#imgdetectionextended)
- [ImgDetectionsExtended](#imgdetectionsextended)
- [Keypoint](#keypoint)
- [Keypoints](#keypoints)
- [Line](#line)
- [Lines](#lines)
- [Map2d](#map2d)
- [Prediction](#prediction)
- [Predictions](#predictions)
- [SegmentationMask](#segmentationmask)
- [SegmentationMaskSAM](#segmentationmaskssam)
- [Message Types](#message-types)
- [Classifications](#classifications)
- [Attributes](#attributes)
- [Cluster](#cluster)
- [Attributes](#attributes-1)
- [Clusters](#clusters)
- [Attributes](#attributes-2)
- [ImgDetectionExtended](#imgdetectionextended)
- [Attributes](#attributes-3)
- [ImgDetectionsExtended](#imgdetectionsextended)
- [Attributes](#attributes-4)
- [Keypoint](#keypoint)
- [Attributes](#attributes-5)
- [Keypoints](#keypoints)
- [Attributes](#attributes-6)
- [Line](#line)
- [Attributes](#attributes-7)
- [Lines](#lines)
- [Attributes](#attributes-8)
- [Map2D](#map2d)
- [Attributes](#attributes-9)
- [Prediction](#prediction)
- [Attributes](#attributes-10)
- [Predictions](#predictions)
- [Attributes](#attributes-11)
- [SegmentationMask](#segmentationmask)
- [Attributes](#attributes-12)
- [SegmentationMasksSAM](#segmentationmaskssam)
- [Attributes](#attributes-13)

## Classifications

Expand Down Expand Up @@ -51,11 +66,7 @@ A class for storing image detections in (x_center, y_center, width, height) form

### Attributes

- **x_center** (float): The X coordinate of the center of the bounding box, relative to the input width.
- **y_center** (float): The Y coordinate of the center of the bounding box, relative to the input height.
- **width** (float): The width of the bounding box, relative to the input width.
- **height** (float): The height of the bounding box, relative to the input height.
- **angle** (float): The angle of the bounding box expressed in degrees.
- **rotated_rect** (dai.RotatedRect): A depthai object for storing the roated bounding box information. The bounding box is stored as x_center, y_center, width, height, angle in degrees.
- **confidence** (float): Confidence of the detection.
- **label** (int): Label of the detection.
- **keypoints** (List\[[Keypoint](#keypoint)\]): Keypoints of the detection.
Expand Down
21 changes: 13 additions & 8 deletions depthai_nodes/ml/messages/creators/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,22 @@ def create_detection_message(
detections = []
for detection_idx in range(n_bboxes):
detection = ImgDetectionExtended()
x_center, y_center, width, height = bboxes[detection_idx]
detection.x_center = x_center.item()
detection.y_center = y_center.item()
detection.width = width.item()
detection.height = height.item()
detection.confidence = scores[detection_idx].item()

x_center, y_center, width, height = bboxes[detection_idx]
angle = 0
if angles is not None:
detection.angle = angles[detection_idx].item()
angle = float(angles[detection_idx])
detection.rotated_rect = (
float(x_center),
float(y_center),
float(width),
float(height),
angle,
)
detection.confidence = float(scores[detection_idx])

if labels is not None:
detection.label = labels[detection_idx].item()
detection.label = int(labels[detection_idx])
if keypoints is not None:
if keypoints_scores is not None:
detection.keypoints = transform_to_keypoints(
Expand Down
212 changes: 42 additions & 170 deletions depthai_nodes/ml/messages/img_detections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

import depthai as dai
import numpy as np
Expand All @@ -10,161 +10,51 @@

class ImgDetectionExtended(dai.Buffer):
"""A class for storing image detections in (x_center, y_center, width, height)
format with additional angle and keypoints.
format with additional angle, label and keypoints.

Attributes
----------
x_center: float
The X coordinate of the center of the bounding box, relative to the input width.
y_center: float
The Y coordinate of the center of the bounding box, relative to the input height.
width: float
The width of the bounding box, relative to the input width.
height: float
The height of the bounding box, relative to the input height.
angle: float
The angle of the bounding box expressed in degrees.
rotated_rect: dai.RotatedRect
Rotated rectangle object defined by the center, width, height and angle in degrees.
confidence: float
Confidence of the detection.
label: int
Label of the detection.
label_name: str
The corresponding label name if available.
keypoints: List[Keypoint]
Keypoints of the detection.
"""

def __init__(self):
"""Initializes the ImgDetectionExtended object."""
super().__init__()
self._x_center: float
self._y_center: float
self._width: float
self._height: float

self._angle: float = 0.0
self._rotated_rect: dai.RotatedRect
self._confidence: float = -1.0
self._label: int = -1
self._label_name: str = ""
self._keypoints: List[Keypoint] = []

@property
def x_center(self) -> float:
"""Returns the X coordinate of the center of the bounding box.
def rotated_rect(self) -> dai.RotatedRect:
"""Returns the rotated rectangle representing the bounding box.

@return: X coordinate of the center of the bounding box.
@rtype: float
@return: Rotated rectangle object
@rtype: dai.RotatedRect
"""
return self._x_center
return self._rotated_rect

@x_center.setter
def x_center(self, value: float):
"""Sets the X coordinate of the center of the bounding box.
@rotated_rect.setter
def rotated_rect(self, rectangle: Tuple[float, float, float, float, float]):
"""Sets the rotated rectangle of the bounding box.

@param value: X coordinate of the center of the bounding box.
@type value: float
@raise TypeError: If value is not a float.
@raise ValueError: If value is not between 0 and 1.
@param value: Tuple of (x_center, y_center, width, height, angle).
@type value: tuple[float, float, float, float, float]
"""
if not isinstance(value, float):
raise TypeError("X center must be a float.")
if value < 0 or value > 1:
raise ValueError("X center must be between 0 and 1.")
self._x_center = value
center = dai.Point2f(rectangle[0], rectangle[1])
size = dai.Size2f(rectangle[2], rectangle[3])

@property
def y_center(self) -> float:
"""Returns the Y coordinate of the center of the bounding box.

@return: Y coordinate of the center of the bounding box.
@rtype: float
"""
return self._y_center

@y_center.setter
def y_center(self, value: float):
"""Sets the Y coordinate of the center of the bounding box.

@param value: Y coordinate of the center of the bounding box.
@type value: float
@raise TypeError: If value is not a float.
@raise ValueError: If value is not between 0 and 1.
"""
if not isinstance(value, float):
raise TypeError("Y center must be a float.")
if value < 0 or value > 1:
raise ValueError("Y center must be between 0 and 1.")
self._y_center = value

@property
def width(self) -> float:
"""Returns the width of the bounding box.

@return: Width of the bounding box.
@rtype: float
"""
return self._width

@width.setter
def width(self, value: float):
"""Sets the width of the bounding box.

@param value: Width of the bounding box.
@type value: float
@raise TypeError: If value is not a float.
@raise ValueError: If value is not between 0 and 1.
"""
if not isinstance(value, float):
raise TypeError("Width must be a float.")
if value < 0 or value > 1:
raise ValueError("Width must be between 0 and 1.")

self._width = value

@property
def height(self) -> float:
"""Returns the height of the bounding box.

@return: Height of the bounding box.
@rtype: float
"""
return self._height

@height.setter
def height(self, value: float):
"""Sets the height of the bounding box.

@param value: Height of the bounding box.
@type value: float
@raise TypeError: If value is not a float.
@raise ValueError: If value is not between 0 and 1.
"""
if not isinstance(value, float):
raise TypeError("Height must be a float.")
if value < 0 or value > 1:
raise ValueError("Height must be between 0 and 1.")
self._height = value

@property
def angle(self) -> float:
"""Returns the angle of the bounding box.

@return: Angle of the bounding box.
@rtype: float
"""
return self._angle

@angle.setter
def angle(self, value: float):
"""Sets the angle of the bounding box.

@param value: Angle of the bounding box.
@type value: float
@raise TypeError: If value is not a float.
@raise TypeError: If value is not between -360 and 360.
"""
if not isinstance(value, float):
raise TypeError("Angle must be a float.")
if value < -360 or value > 360:
raise TypeError("Angle must be between -360 and 360 degrees.")
self._angle = value
self._rotated_rect = dai.RotatedRect(center, size, rectangle[4])

@property
def confidence(self) -> float:
Expand Down Expand Up @@ -211,6 +101,27 @@ def label(self, value: int):
raise TypeError("Label must be an integer.")
self._label = value

@property
def label_name(self) -> str:
"""Returns the label name of the detection.

@return: Label name of the detection.
@rtype: str
"""
return self._label_name

@label_name.setter
def label_name(self, value: str):
"""Sets the label name of the detection.

@param value: Label name of the detection.
@type value: str
@raise TypeError: If value is not a string.
"""
if not isinstance(value, str):
raise TypeError("Label name must be a string.")
self._label_name = value

@property
def keypoints(
self,
Expand Down Expand Up @@ -240,45 +151,6 @@ def keypoints(
raise ValueError("Keypoints must be a list of Keypoint objects.")
self._keypoints = value

def get_xyxy_bbox_points(self) -> NDArray[np.float32]:
"""Returns the axis-aligned [x1, y1, x2, y2] bounding box points. It does not
take into account the angle of the bounding box.

@return: Bounding box points.
@rtype: NDArray[np.float32]
"""
x1 = self.x_center - self.width / 2
y1 = self.y_center - self.height / 2
x2 = self.x_center + self.width / 2
y2 = self.y_center + self.height / 2
return np.array([x1, y1, x2, y2], dtype=np.float32)

def get_bbox_points(self) -> NDArray[np.float32]:
"""Returns the bounding box points in the format [[x1,y1], [x2,y2], [x3,y3],
[x4,y4]]. Starting from the top-left corner and going clockwise. This is useful
for drawing rotated bounding boxes.

@return: Rotated bounding box points.
@rtype: NDArray[np.float32]
"""

angle = np.radians(self.angle)
x_center, y_center, w, h = self.x_center, self.y_center, self.width, self.height

w_half, h_half = w / 2, h / 2

corners = np.array(
[[-w_half, -h_half], [w_half, -h_half], [w_half, h_half], [-w_half, h_half]]
)

cos_a, sin_a = np.cos(angle), np.sin(angle)
rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
rotated_corners = np.dot(corners, rotation_matrix) + np.array(
[x_center, y_center]
)

return rotated_corners


class ImgDetectionsExtended(dai.Buffer):
"""ImgDetectionsExtended class for storing image detections with keypoints.
Expand Down
5 changes: 3 additions & 2 deletions depthai_nodes/ml/parsers/mediapipe_palm_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ def run(self):
scores = None

for tensor_name in all_tensors:
tensor = output.getTensor(tensor_name, dequantize=True).astype(
np.float32
tensor = np.array(
output.getTensor(tensor_name, dequantize=True), dtype=np.float32
)

if bboxes is None:
bboxes = tensor
scores = tensor
Expand Down
1 change: 1 addition & 0 deletions depthai_nodes/parser_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def build(self, nn_archive: dai.NNArchive, head_index: int = None) -> Dict:
"""

heads = nn_archive.getConfig().model.heads

indexes = range(len(heads))

if len(heads) == 0:
Expand Down
Loading
Loading