Commit

Add visualization methods to all message types (#141)
* added visualizationMessages

* timestamps and removed duplicated code
aljazkonec1 authored Dec 10, 2024
1 parent f11ad94 commit 200bf4f
Showing 10 changed files with 227 additions and 4 deletions.
6 changes: 6 additions & 0 deletions depthai_nodes/ml/helpers/constants.py
@@ -0,0 +1,6 @@
import depthai as dai

OUTLINE_COLOR = dai.Color(1.0, 0.5, 0.5, 1.0)
TEXT_COLOR = dai.Color(0.5, 0.5, 1.0, 1.0)
BACKGROUND_COLOR = dai.Color(1.0, 1.0, 0.5, 1.0)
KEYPOINT_COLOR = dai.Color(1.0, 0.35, 0.367, 1.0)
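These shared defaults are plain dai.Color values, so a custom palette can be defined the same way. A minimal sketch for illustration, assuming the arguments are normalized RGBA components in [0, 1], as the values above suggest:

import depthai as dai

# Hypothetical custom color following the same layout as the constants above:
# red, green, blue, alpha, each normalized to [0, 1].
HIGHLIGHT_COLOR = dai.Color(0.0, 1.0, 0.0, 1.0)  # opaque green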
27 changes: 27 additions & 0 deletions depthai_nodes/ml/messages/classification.py
@@ -4,6 +4,11 @@
import numpy as np
from numpy.typing import NDArray

from depthai_nodes.ml.helpers.constants import (
    BACKGROUND_COLOR,
    TEXT_COLOR,
)


class Classifications(dai.Buffer):
"""Classification class for storing the classes and their respective scores.
Expand Down Expand Up @@ -119,3 +124,25 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgAnnotations:
        """Returns default visualization message for classification.

        The message adds the top five classes and their scores to the right side of the
        image.
        """
        img_annotations = dai.ImgAnnotations()
        annotation = dai.ImgAnnotation()

        for i in range(len(self._scores)):
            text = dai.TextAnnotation()
            text.position = dai.Point2f(1.05, 0.1 + i * 0.1)
            text.text = f"{self._classes[i]} {self._scores[i] * 100:.0f}%"
            text.fontSize = 15
            text.textColor = TEXT_COLOR
            text.backgroundColor = BACKGROUND_COLOR
            annotation.texts.append(text)

        img_annotations.annotations.append(annotation)
        img_annotations.setTimestamp(self.getTimestamp())
        return img_annotations
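A minimal host-side sketch of how these annotations might be consumed, for illustration only; it assumes `msg` is a Classifications message obtained from a parser output queue and uses only the attributes set above:

# Hypothetical usage sketch; `msg` is assumed to be a Classifications message.
annotations = msg.getVisualizationMessage()  # dai.ImgAnnotations

for ann in annotations.annotations:
    for text in ann.texts:
        print(text.text)  # e.g. "tabby_cat 87%"
print(annotations.getTimestamp())  # carries the source message timestamp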
31 changes: 30 additions & 1 deletion depthai_nodes/ml/messages/clusters.py
@@ -1,6 +1,8 @@
from typing import List

import cv2
import depthai as dai
import numpy as np


class Cluster(dai.Buffer):
@@ -18,7 +20,7 @@ def __init__(self):
        """Initializes the Cluster object."""
        super().__init__()
        self._label: int = None
-       self.points: List[dai.Point2f] = []
+       self._points: List[dai.Point2f] = []

    @property
    def label(self) -> int:
@@ -131,3 +133,30 @@ def transformation(self, value: dai.ImgTransformation):
                f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
            )
        self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgAnnotations:
        """Creates a default visualization message for clusters and colors each one
        separately."""
        img_annotations = dai.ImgAnnotations()
        annotation = dai.ImgAnnotation()

        num_clusters = len(self.clusters)
        color_mask = np.array(range(0, 255, 255 // num_clusters), dtype=np.uint8)
        color_mask = cv2.applyColorMap(color_mask, cv2.COLORMAP_RAINBOW)
        color_mask = color_mask / 255
        color_mask = color_mask.reshape(-1, 3)

        for i, cluster in enumerate(self.clusters):
            pointsAnnotation = dai.PointsAnnotation()
            pointsAnnotation.type = dai.PointsAnnotationType.POINTS
            pointsAnnotation.points = dai.VectorPoint2f(cluster.points)
            r, g, b = color_mask[i]
            color = dai.Color(r, g, b)
            pointsAnnotation.outlineColor = color
            pointsAnnotation.fillColor = color
            pointsAnnotation.thickness = 2.0
            annotation.points.append(pointsAnnotation)

        img_annotations.annotations.append(annotation)
        img_annotations.setTimestamp(self.getTimestamp())
        return img_annotations
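The per-cluster colors come from evenly spaced indices run through OpenCV's rainbow colormap. A standalone sketch of that coloring step, for illustration, assuming numpy and OpenCV are available and the cluster count is non-zero:

import cv2
import numpy as np

num_clusters = 4  # hypothetical cluster count
indices = np.array(range(0, 255, 255 // num_clusters), dtype=np.uint8)  # evenly spaced levels
colors = cv2.applyColorMap(indices, cv2.COLORMAP_RAINBOW) / 255  # normalized 3-channel colors
colors = colors.reshape(-1, 3)
print(colors[:num_clusters])  # one row per cluster, used above for outline and fill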
55 changes: 54 additions & 1 deletion depthai_nodes/ml/messages/img_detections.py
@@ -4,6 +4,13 @@
import numpy as np
from numpy.typing import NDArray

from depthai_nodes.ml.helpers.constants import (
    BACKGROUND_COLOR,
    KEYPOINT_COLOR,
    OUTLINE_COLOR,
    TEXT_COLOR,
)

from .keypoints import Keypoint
from .segmentation import SegmentationMask

@@ -199,7 +206,7 @@ def detections(self, value: List[ImgDetectionExtended]):
        self._detections = value

    @property
-   def masks(self) -> NDArray[np.int8]:
+   def masks(self) -> NDArray[np.int16]:
        """Returns the segmentation masks stored in a single numpy array.

        @return: Segmentation masks.
@@ -253,3 +260,49 @@ def transformation(self, value: dai.ImgTransformation):
                f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
            )
        self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgAnnotations:
        img_annotations = dai.ImgAnnotations()
        annotation = dai.ImgAnnotation()
        transformation = self.transformation
        w, h = 1, 1
        if transformation is not None:  # remove once RVC2 supports ImgTransformation
            w, h = transformation.getSize()

        for detection in self.detections:
            detection: ImgDetectionExtended = detection
            rotated_rect = detection.rotated_rect
            rotated_rect = rotated_rect.denormalize(w, h)
            points = rotated_rect.getPoints()
            points = [dai.Point2f(point.x / w, point.y / h) for point in points]
            pointsAnnotation = dai.PointsAnnotation()
            pointsAnnotation.type = dai.PointsAnnotationType.LINE_LOOP
            pointsAnnotation.points = dai.VectorPoint2f(points)
            pointsAnnotation.outlineColor = OUTLINE_COLOR
            pointsAnnotation.thickness = 2.0
            annotation.points.append(pointsAnnotation)

            text = dai.TextAnnotation()
            text.position = points[0]
            text.text = f"{detection.label_name} {int(detection.confidence * 100)}%"
            text.fontSize = 15
            text.textColor = TEXT_COLOR
            text.backgroundColor = BACKGROUND_COLOR
            annotation.texts.append(text)

            if len(detection.keypoints) > 0:
                keypoints = [
                    dai.Point2f(keypoint.x, keypoint.y)
                    for keypoint in detection.keypoints
                ]
                keypointAnnotation = dai.PointsAnnotation()
                keypointAnnotation.type = dai.PointsAnnotationType.POINTS
                keypointAnnotation.points = dai.VectorPoint2f(keypoints)
                keypointAnnotation.outlineColor = KEYPOINT_COLOR
                keypointAnnotation.fillColor = KEYPOINT_COLOR
                keypointAnnotation.thickness = 2
                annotation.points.append(keypointAnnotation)

        img_annotations.annotations.append(annotation)
        img_annotations.setTimestamp(self.getTimestamp())
        return img_annotations
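A hypothetical consumer sketch (OpenCV only, not part of the changed files) showing how the normalized annotations built above could be drawn back onto an image; `frame` and `msg` are assumed to be the BGR frame and the matching ImgDetectionsExtended message:

import cv2
import depthai as dai
import numpy as np

annotations = msg.getVisualizationMessage()
h, w = frame.shape[:2]
for ann in annotations.annotations:
    for pts in ann.points:  # rotated-rect outlines (LINE_LOOP) and keypoints (POINTS)
        pix = np.array([[p.x * w, p.y * h] for p in pts.points], dtype=np.int32)
        if pts.type == dai.PointsAnnotationType.LINE_LOOP:
            cv2.polylines(frame, [pix], isClosed=True, color=(0, 255, 0), thickness=2)
        else:
            for x, y in pix:
                cv2.circle(frame, (int(x), int(y)), 3, (0, 0, 255), -1)
    for text in ann.texts:
        org = (int(text.position.x * w), int(text.position.y * h))
        cv2.putText(frame, text.text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)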
19 changes: 19 additions & 0 deletions depthai_nodes/ml/messages/keypoints.py
@@ -2,6 +2,8 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import KEYPOINT_COLOR


class Keypoint(dai.Buffer):
"""Keypoint class for storing a keypoint.
Expand Down Expand Up @@ -185,3 +187,20 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgAnnotations:
        """Creates a default visualization message for the keypoints."""
        img_annotations = dai.ImgAnnotations()
        annotation = dai.ImgAnnotation()
        keypoints = [dai.Point2f(keypoint.x, keypoint.y) for keypoint in self.keypoints]
        pointsAnnotation = dai.PointsAnnotation()
        pointsAnnotation.type = dai.PointsAnnotationType.POINTS
        pointsAnnotation.points = dai.VectorPoint2f(keypoints)
        pointsAnnotation.outlineColor = KEYPOINT_COLOR
        pointsAnnotation.fillColor = KEYPOINT_COLOR
        pointsAnnotation.thickness = 2
        annotation.points.append(pointsAnnotation)

        img_annotations.annotations.append(annotation)
        img_annotations.setTimestamp(self.getTimestamp())
        return img_annotations
24 changes: 24 additions & 0 deletions depthai_nodes/ml/messages/lines.py
@@ -2,6 +2,8 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import OUTLINE_COLOR


class Line(dai.Buffer):
"""Line class for storing a line.
Expand Down Expand Up @@ -158,3 +160,25 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgAnnotations:
        """Returns default visualization message for lines.

        The message adds lines to the image.
        """
        img_annotation = dai.ImgAnnotations()
        annotation = dai.ImgAnnotation()

        for line in self.lines:
            pointsAnnotation = dai.PointsAnnotation()
            pointsAnnotation.type = dai.PointsAnnotationType.LINE_STRIP
            pointsAnnotation.points = dai.VectorPoint2f(
                [line.start_point, line.end_point]
            )
            pointsAnnotation.outlineColor = OUTLINE_COLOR
            pointsAnnotation.thickness = 2.0
            annotation.points.append(pointsAnnotation)

        img_annotation.annotations.append(annotation)
        img_annotation.setTimestamp(self.getTimestamp())
        return img_annotation
15 changes: 14 additions & 1 deletion depthai_nodes/ml/messages/map.py
@@ -1,3 +1,4 @@
import cv2
import depthai as dai
import numpy as np
from numpy.typing import NDArray
@@ -92,10 +93,22 @@ def transformation(self, value: dai.ImgTransformation):
        @type value: dai.ImgTransformation
        @raise TypeError: If value is not a dai.ImgTransformation object.
        """

        if value is not None:
            if not isinstance(value, dai.ImgTransformation):
                raise TypeError(
                    f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
                )

        self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgFrame:
        """Returns default visualization message for 2D maps in the form of a
        colormapped image."""
        img_frame = dai.ImgFrame()
        mask = self._map.copy()
        if np.any(mask < 1):
            mask = mask * 255
        mask = mask.astype(np.uint8)

        colored_mask = cv2.applyColorMap(mask, cv2.COLORMAP_PLASMA)
        return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
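Because the map is colormapped into a regular BGR frame, it can be shown directly with OpenCV. A minimal sketch for illustration, assuming `map_msg` is a Map2D message (for example a depth or confidence map normalized to [0, 1]) taken from a parser output queue:

import cv2

frame = map_msg.getVisualizationMessage()  # dai.ImgFrame holding the PLASMA-colored map
cv2.imshow("map", frame.getCvFrame())      # getCvFrame() returns the underlying numpy image
cv2.waitKey(1)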
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/prediction.py
@@ -2,6 +2,11 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import (
    BACKGROUND_COLOR,
    TEXT_COLOR,
)


class Prediction(dai.Buffer):
"""Prediction class for storing a prediction.
Expand Down Expand Up @@ -117,3 +122,24 @@ def transformation(self, value: dai.ImgTransformation):
f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
)
self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgAnnotations:
        """Returns the visualization message for the predictions.

        The message adds text representing the predictions to the right of the image.
        """
        img_annotations = dai.ImgAnnotations()
        annotation = dai.ImgAnnotation()

        for i, prediction in enumerate(self.predictions):
            text = dai.TextAnnotation()
            text.position = dai.Point2f(1.05, 0.1 + i * 0.1)
            text.text = f"{prediction.prediction:.2f}"
            text.fontSize = 15
            text.textColor = TEXT_COLOR
            text.backgroundColor = BACKGROUND_COLOR
            annotation.texts.append(text)

        img_annotations.annotations.append(annotation)
        img_annotations.setTimestamp(self.getTimestamp())
        return img_annotations
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/segmentation.py
@@ -1,3 +1,4 @@
import cv2
import depthai as dai
import numpy as np
from numpy.typing import NDArray
@@ -76,3 +77,28 @@ def transformation(self, value: dai.ImgTransformation):
                f"Transformation must be a dai.ImgTransformation object, instead got {type(value)}."
            )
        self._transformation = value

    def getVisualizationMessage(self) -> dai.ImgFrame:
        """Returns the default visualization message for segmentation masks."""
        img_frame = dai.ImgFrame()
        mask = self._mask.copy()

        unique_values = np.unique(mask[mask >= 0])
        scaled_mask = np.zeros_like(mask, dtype=np.uint8)

        if unique_values.size == 0:
            return img_frame.setCvFrame(scaled_mask, dai.ImgFrame.Type.BGR888i)

        min_val, max_val = unique_values.min(), unique_values.max()

        if min_val == max_val:
            scaled_mask = np.ones_like(mask, dtype=np.uint8) * 255
        else:
            scaled_mask = ((mask - min_val) / (max_val - min_val) * 255).astype(
                np.uint8
            )
        scaled_mask[mask == -1] = 0
        colored_mask = cv2.applyColorMap(scaled_mask, cv2.COLORMAP_RAINBOW)
        colored_mask[mask == -1] = [0, 0, 0]

        return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
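A toy illustration of the scaling above, independent of DepthAI: class ids are stretched to the full 0-255 range before colormapping, while pixels labeled -1 are forced to black as background.

import cv2
import numpy as np

mask = np.array([[-1, 0], [1, 2]], dtype=np.int16)         # hypothetical 2x2 mask
vals = np.unique(mask[mask >= 0])                          # [0, 1, 2]
scaled = ((mask - vals.min()) / (vals.max() - vals.min()) * 255).astype(np.uint8)
scaled[mask == -1] = 0
colored = cv2.applyColorMap(scaled, cv2.COLORMAP_RAINBOW)  # (2, 2, 3) BGR image
colored[mask == -1] = [0, 0, 0]                            # background stays black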
2 changes: 1 addition & 1 deletion depthai_nodes/ml/parsers/mediapipe_palm_detection.py
@@ -195,7 +195,7 @@ def run(self):
            angles = np.round(angles, 0)

            detections_msg = create_detection_message(
-               bboxes=bboxes, scores=scores, angles=angles
+               bboxes=bboxes, scores=scores, angles=angles, keypoints=points
            )
            detections_msg.setTimestamp(output.getTimestamp())
            detections_msg.transformation = output.getTransformation()
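With keypoints now passed to create_detection_message, each palm detection carries its keypoints, which the ImgDetectionsExtended visualization above renders automatically. A hypothetical downstream check, assuming `detections_msg` as produced by this parser:

for detection in detections_msg.detections:
    for kp in detection.keypoints:
        print(kp.x, kp.y)  # normalized keypoint coordinates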