Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization methods to all message types #141

Merged
merged 23 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3a337e6
added visualizationMessage to some messages
aljazkonec1 Nov 18, 2024
a953c9c
added visualizationMessage to some messages
aljazkonec1 Nov 18, 2024
4d3a295
remove print
aljazkonec1 Nov 19, 2024
012e165
Merge branch 'feat/detection_visualization_node' of https://github.co…
aljazkonec1 Nov 19, 2024
a185ab8
remove print
aljazkonec1 Nov 19, 2024
e5d057b
added annotations to all message types
aljazkonec1 Nov 26, 2024
00fa3c1
timestamps and removed duplicated code
aljazkonec1 Nov 27, 2024
e63f66d
classification updates
aljazkonec1 Nov 27, 2024
23e6c55
added dai.ImgTransformations attributes for messages.
aljazkonec1 Nov 28, 2024
905f522
rebase
aljazkonec1 Nov 18, 2024
850cedf
remove print
aljazkonec1 Nov 19, 2024
43be54c
rebase
aljazkonec1 Nov 18, 2024
73ebfcf
remove print
aljazkonec1 Nov 19, 2024
b4ec282
added annotations to all message types
aljazkonec1 Nov 28, 2024
f152ba0
Merge branch 'feat/detection_visualization_node' of https://github.co…
aljazkonec1 Nov 28, 2024
4432afd
fixed shearing of frames.
aljazkonec1 Nov 29, 2024
581fbf8
Merge branch 'main' into feat/detection_visualization_node
aljazkonec1 Nov 29, 2024
9cf6790
added points back
aljazkonec1 Nov 29, 2024
13a9ee0
Merge branch 'feat/detection_visualization_node' of https://github.co…
aljazkonec1 Nov 29, 2024
605cea2
Updated annotations
aljazkonec1 Dec 9, 2024
73e0651
pre-commit error
aljazkonec1 Dec 9, 2024
d037f3c
Merge branch 'main' into feat/detection_visualization_node
aljazkonec1 Dec 9, 2024
3f4ff72
pre-commit
aljazkonec1 Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions depthai_nodes/ml/helpers/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import depthai as dai

OUTLINE_COLOR = dai.Color(1.0, 0.5, 0.5, 1.0)
TEXT_COLOR = dai.Color(0.5, 0.5, 1.0, 1.0)
BACKGROUND_COLOR = dai.Color(1.0, 1.0, 0.5, 1.0)
KEYPOINT_COLOR = dai.Color(1.0, 0.35, 0.367, 1.0)
27 changes: 27 additions & 0 deletions depthai_nodes/ml/messages/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import numpy as np
from numpy.typing import NDArray

from depthai_nodes.ml.helpers.constants import (
BACKGROUND_COLOR,
TEXT_COLOR,
)


class Classifications(dai.Buffer):
"""Classification class for storing the classes and their respective scores.
Expand Down Expand Up @@ -119,3 +124,25 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns default visualization message for classification.
The message adds the top five classes and their scores to the right side of the
image.
"""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

for i in range(len(self._scores)):
text = dai.TextAnnotation()
text.position = dai.Point2f(1.05, 0.1 + i * 0.1)
text.text = f"{self._classes[i]} {self._scores[i] * 100:.0f}%"
text.fontSize = 15
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
annotation.texts.append(text)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
31 changes: 30 additions & 1 deletion depthai_nodes/ml/messages/clusters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import List

import cv2
import depthai as dai
import numpy as np


class Cluster(dai.Buffer):
Expand All @@ -18,7 +20,7 @@ def __init__(self):
"""Initializes the Cluster object."""
super().__init__()
self._label: int = None
self.points: List[dai.Point2f] = []
self._points: List[dai.Point2f] = []

@property
def label(self) -> int:
Expand Down Expand Up @@ -131,3 +133,30 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Creates a default visualization message for clusters and colors each one
separately."""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

num_clusters = len(self.clusters)
color_mask = np.array(range(0, 255, 255 // num_clusters), dtype=np.uint8)
color_mask = cv2.applyColorMap(color_mask, cv2.COLORMAP_RAINBOW)
color_mask = color_mask / 255
color_mask = color_mask.reshape(-1, 3)

for i, cluster in enumerate(self.clusters):
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.POINTS
pointsAnnotation.points = dai.VectorPoint2f(cluster.points)
r, g, b = color_mask[i]
color = dai.Color(r, g, b)
pointsAnnotation.outlineColor = color
pointsAnnotation.fillColor = color
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
55 changes: 54 additions & 1 deletion depthai_nodes/ml/messages/img_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
import numpy as np
from numpy.typing import NDArray

from depthai_nodes.ml.helpers.constants import (
BACKGROUND_COLOR,
KEYPOINT_COLOR,
OUTLINE_COLOR,
TEXT_COLOR,
)

from .keypoints import Keypoint
from .segmentation import SegmentationMask

Expand Down Expand Up @@ -199,7 +206,7 @@ def detections(self, value: List[ImgDetectionExtended]):
self._detections = value

@property
def masks(self) -> NDArray[np.int8]:
def masks(self) -> NDArray[np.int16]:
"""Returns the segmentation masks stored in a single numpy array.
@return: Segmentation masks.
Expand Down Expand Up @@ -253,3 +260,49 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()
transformation = self.transformation
w, h = 1, 1
if transformation is not None: # remove once RVC2 supports ImgTransformation
w, h = transformation.getSize()

for detection in self.detections:
detection: ImgDetectionExtended = detection
rotated_rect = detection.rotated_rect
rotated_rect = rotated_rect.denormalize(w, h)
points = rotated_rect.getPoints()
points = [dai.Point2f(point.x / w, point.y / h) for point in points]
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.LINE_LOOP
pointsAnnotation.points = dai.VectorPoint2f(points)
pointsAnnotation.outlineColor = OUTLINE_COLOR
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

text = dai.TextAnnotation()
text.position = points[0]
text.text = f"{detection.label_name} {int(detection.confidence * 100)}%"
text.fontSize = 15
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
annotation.texts.append(text)

if len(detection.keypoints) > 0:
keypoints = [
dai.Point2f(keypoint.x, keypoint.y)
for keypoint in detection.keypoints
]
keypointAnnotation = dai.PointsAnnotation()
keypointAnnotation.type = dai.PointsAnnotationType.POINTS
keypointAnnotation.points = dai.VectorPoint2f(keypoints)
keypointAnnotation.outlineColor = KEYPOINT_COLOR
keypointAnnotation.fillColor = KEYPOINT_COLOR
keypointAnnotation.thickness = 2
annotation.points.append(keypointAnnotation)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
19 changes: 19 additions & 0 deletions depthai_nodes/ml/messages/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import KEYPOINT_COLOR


class Keypoint(dai.Buffer):
"""Keypoint class for storing a keypoint.
Expand Down Expand Up @@ -185,3 +187,20 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Creates a default visualization message for the keypoints."""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()
keypoints = [dai.Point2f(keypoint.x, keypoint.y) for keypoint in self.keypoints]
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.POINTS
pointsAnnotation.points = dai.VectorPoint2f(keypoints)
pointsAnnotation.outlineColor = KEYPOINT_COLOR
pointsAnnotation.fillColor = KEYPOINT_COLOR
pointsAnnotation.thickness = 2
annotation.points.append(pointsAnnotation)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
24 changes: 24 additions & 0 deletions depthai_nodes/ml/messages/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import OUTLINE_COLOR


class Line(dai.Buffer):
"""Line class for storing a line.
Expand Down Expand Up @@ -158,3 +160,25 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns default visualization message for lines.
The message adds lines to the image.
"""
img_annotation = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

for line in self.lines:
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.LINE_STRIP
pointsAnnotation.points = dai.VectorPoint2f(
[line.start_point, line.end_point]
)
pointsAnnotation.outlineColor = OUTLINE_COLOR
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

img_annotation.annotations.append(annotation)
img_annotation.setTimestamp(self.getTimestamp())
return img_annotation
15 changes: 14 additions & 1 deletion depthai_nodes/ml/messages/map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cv2
import depthai as dai
import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -92,10 +93,22 @@ def transformation(self, value: dai.ImgTransformation):
@type value: dai.ImgTransformation
@raise TypeError: If value is not a dai.ImgTransformation object.
"""

if value is not None:
if not isinstance(value, dai.ImgTransformation):
raise TypeError(
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)

self._transformation = value

def getVisualizationMessage(self) -> dai.ImgFrame:
"""Returns default visualization message for 2D maps in the form of a
colormapped image."""
img_frame = dai.ImgFrame()
mask = self._map.copy()
if np.any(mask < 1):
mask = mask * 255
mask = mask.astype(np.uint8)

colored_mask = cv2.applyColorMap(mask, cv2.COLORMAP_PLASMA)
return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import (
BACKGROUND_COLOR,
TEXT_COLOR,
)


class Prediction(dai.Buffer):
"""Prediction class for storing a prediction.
Expand Down Expand Up @@ -117,3 +122,24 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns the visualization message for the predictions.
The message adds text representing the predictions to the right of the image.
"""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

for i, prediction in enumerate(self.predictions):
text = dai.TextAnnotation()
text.position = dai.Point2f(1.05, 0.1 + i * 0.1)
text.text = f"{prediction.prediction:.2f}"
text.fontSize = 15
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
annotation.texts.append(text)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cv2
import depthai as dai
import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -76,3 +77,28 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

def getVisualizationMessage(self) -> dai.ImgFrame:
"""Returns the default visualization message for segmentation masks."""
img_frame = dai.ImgFrame()
mask = self._mask.copy()

unique_values = np.unique(mask[mask >= 0])
scaled_mask = np.zeros_like(mask, dtype=np.uint8)

if unique_values.size == 0:
return img_frame.setCvFrame(scaled_mask, dai.ImgFrame.Type.BGR888i)

min_val, max_val = unique_values.min(), unique_values.max()

if min_val == max_val:
scaled_mask = np.ones_like(mask, dtype=np.uint8) * 255
else:
scaled_mask = ((mask - min_val) / (max_val - min_val) * 255).astype(
np.uint8
)
scaled_mask[mask == -1] = 0
colored_mask = cv2.applyColorMap(scaled_mask, cv2.COLORMAP_RAINBOW)
colored_mask[mask == -1] = [0, 0, 0]

return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
2 changes: 1 addition & 1 deletion depthai_nodes/ml/parsers/mediapipe_palm_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def run(self):
angles = np.round(angles, 0)

detections_msg = create_detection_message(
bboxes=bboxes, scores=scores, angles=angles
bboxes=bboxes, scores=scores, angles=angles, keypoints=points
)
detections_msg.setTimestamp(output.getTimestamp())
detections_msg.transformation = output.getTransformation()
Expand Down
Loading