Skip to content

Commit

Permalink
Merge pull request #12 from luxonis/feature/thermal_image_parsing
Browse files Browse the repository at this point in the history
Feature: Add Thermal Image Parser
  • Loading branch information
jkbmrz authored Jul 12, 2024
2 parents ca98e65 + a8017f0 commit c8e1dba
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .segmentation import create_segmentation_message
from .thermal import create_thermal_message
from .tracked_features import create_tracked_features_message

__all__ = [
Expand All @@ -14,4 +15,5 @@
"create_line_detection_message",
"create_tracked_features_message",
"create_keypoints_message",
"create_thermal_message",
]
38 changes: 38 additions & 0 deletions depthai_nodes/ml/messages/creators/thermal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import depthai as dai
import numpy as np


def create_thermal_message(thermal_image: np.array) -> dai.ImgFrame:
"""Creates a thermal image message in the form of an ImgFrame using the provided
thermal image array.
Args:
thermal_image (np.array): A NumPy array representing the thermal image with shape (CHW or HWC).
Returns:
dai.ImgFrame: An ImgFrame object containing the thermal information.
"""

if not isinstance(thermal_image, np.ndarray):
raise ValueError(f"Expected numpy array, got {type(thermal_image)}.")
if len(thermal_image.shape) != 3:
raise ValueError(f"Expected 3D input, got {len(thermal_image.shape)}D input.")

if thermal_image.shape[0] == 1:
thermal_image = thermal_image[0, :, :] # CHW to HW
elif thermal_image.shape[2] == 1:
thermal_image = thermal_image[:, :, 0] # HWC to HW
else:
raise ValueError(
"Unexpected image shape. Expected CHW or HWC, got", thermal_image.shape
)

thermal_image = thermal_image.astype(np.uint16)

imgFrame = dai.ImgFrame()
imgFrame.setFrame(thermal_image)
imgFrame.setWidth(thermal_image.shape[1])
imgFrame.setHeight(thermal_image.shape[0])
imgFrame.setType(dai.ImgFrame.Type.RAW16)

return imgFrame
2 changes: 2 additions & 0 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .scrfd import SCRFDParser
from .segmentation import SegmentationParser
from .superanimal_landmarker import SuperAnimalParser
from .thermal_image import ThermalImageParser
from .xfeat import XFeatParser
from .yunet import YuNetParser

Expand All @@ -22,4 +23,5 @@
"KeypointParser",
"MLSDParser",
"XFeatParser",
"ThermalImageParser",
]
35 changes: 35 additions & 0 deletions depthai_nodes/ml/parsers/thermal_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import depthai as dai

from ..messages.creators import create_thermal_message


class ThermalImageParser(dai.node.ThreadedHostNode):
def __init__(self):
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

def run(self):
"""Postprocessing logic for a model with thermal image output (e.g. UGSR-FA).
Returns:
dai.ImgFrame: uint16, HW thermal image.
"""

while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

output_layer_names = output.getAllLayerNames()
if len(output_layer_names) != 1:
raise ValueError(
f"Expected 1 output layer, got {len(output_layer_names)}."
)
output = output.getTensor(output_layer_names[0])

thermal_map = output[0]

thermal_message = create_thermal_message(thermal_map=thermal_map)
self.out.send(thermal_message)

0 comments on commit c8e1dba

Please sign in to comment.