Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add Thermal Image Parser #12

Merged
merged 6 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .segmentation import create_segmentation_message
from .thermal import create_thermal_message
from .tracked_features import create_tracked_features_message

__all__ = [
Expand All @@ -14,4 +15,5 @@
"create_line_detection_message",
"create_tracked_features_message",
"create_keypoints_message",
"create_thermal_message",
]
38 changes: 38 additions & 0 deletions depthai_nodes/ml/messages/creators/thermal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import depthai as dai
import numpy as np


def create_thermal_message(thermal_image: np.array) -> dai.ImgFrame:
"""Creates a thermal image message in the form of an ImgFrame using the provided
thermal image array.

Args:
thermal_image (np.array): A NumPy array representing the thermal image with shape (CHW or HWC).

Returns:
dai.ImgFrame: An ImgFrame object containing the thermal information.
"""

if not isinstance(thermal_image, np.ndarray):
raise ValueError(f"Expected numpy array, got {type(thermal_image)}.")
if len(thermal_image.shape) != 3:
raise ValueError(f"Expected 3D input, got {len(thermal_image.shape)}D input.")

if thermal_image.shape[0] == 1:
thermal_image = thermal_image[0, :, :] # CHW to HW
elif thermal_image.shape[2] == 1:
thermal_image = thermal_image[:, :, 0] # HWC to HW
else:
raise ValueError(
"Unexpected image shape. Expected CHW or HWC, got", thermal_image.shape
)

thermal_image = thermal_image.astype(np.uint16)

imgFrame = dai.ImgFrame()
imgFrame.setFrame(thermal_image)
imgFrame.setWidth(thermal_image.shape[1])
imgFrame.setHeight(thermal_image.shape[0])
imgFrame.setType(dai.ImgFrame.Type.RAW16)

return imgFrame
2 changes: 2 additions & 0 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .scrfd import SCRFDParser
from .segmentation import SegmentationParser
from .superanimal_landmarker import SuperAnimalParser
from .thermal_image import ThermalImageParser
from .xfeat import XFeatParser
from .yunet import YuNetParser

Expand All @@ -22,4 +23,5 @@
"KeypointParser",
"MLSDParser",
"XFeatParser",
"ThermalImageParser",
]
35 changes: 35 additions & 0 deletions depthai_nodes/ml/parsers/thermal_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import depthai as dai

from ..messages.creators import create_thermal_message


class ThermalImageParser(dai.node.ThreadedHostNode):
def __init__(self):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

def run(self):
"""Postprocessing logic for a model with thermal image output (e.g. UGSR-FA).

Returns:
dai.ImgFrame: uint16, HW thermal image.
"""

while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

output_layer_names = output.getAllLayerNames()
if len(output_layer_names) != 1:
raise ValueError(
f"Expected 1 output layer, got {len(output_layer_names)}."
)
output = output.getTensor(output_layer_names[0])

thermal_map = output[0]

thermal_message = create_thermal_message(thermal_map=thermal_map)
self.out.send(thermal_message)