Fix creators. #42

Merged: 6 commits, Sep 2, 2024
29 changes: 15 additions & 14 deletions depthai_nodes/ml/messages/creators/depth.py
@@ -11,35 +11,36 @@ def create_depth_message(
) -> dai.ImgFrame:
"""Create a DepthAI message for a depth map.

@param depth_map: A NumPy array representing the depth map with shape (HW).
@param depth_map: A NumPy array representing the depth map with shape HW or NHW/HWN.
Here N stands for batch dimension.
@type depth_map: np.array
@param depth_type: A string indicating the type of depth map. It can either be
'relative' or 'metric'.
@type depth_type: Literal['relative', 'metric']
@return: An ImgFrame object containing the depth information.
@rtype: dai.ImgFrame
@raise ValueError: If the depth map is not a NumPy array.
@raise ValueError: If the depth map is not 2D.
@raise ValueError: If the depth map is not 2D or 3D.
@raise ValueError: If the depth map shape is not NHW or HWN.
@raise ValueError: If the depth type is not 'relative' or 'metric'.
@raise NotImplementedError: If the depth type is 'metric'.
"""

if not isinstance(depth_map, np.ndarray):
raise ValueError(f"Expected numpy array, got {type(depth_map)}.")

if len(depth_map.shape) != 3:
raise ValueError(f"Expected 3D input, got {len(depth_map.shape)}D input.")

if depth_map.shape[0] == 1:
depth_map = depth_map[0, :, :] # CHW to HW
elif depth_map.shape[2] == 1:
depth_map = depth_map[:, :, 0] # HWC to HW
else:
raise ValueError(
f"Unexpected image shape. Expected CHW or HWC, got {depth_map.shape}."
)
if len(depth_map.shape) == 3:
if depth_map.shape[0] == 1:
depth_map = depth_map[0, :, :] # CHW to HW
elif depth_map.shape[2] == 1:
depth_map = depth_map[:, :, 0] # HWC to HW
else:
raise ValueError(
f"Unexpected image shape. Expected NHW or HWN, got {depth_map.shape}."
)

if len(depth_map.shape) != 2:
raise ValueError(f"Expected 2D input, got {len(depth_map.shape)}D input.")
raise ValueError(f"Expected 2D or 3D input, got {len(depth_map.shape)}D input.")

if depth_type == "relative":
data_type = dai.ImgFrame.Type.RAW16
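
For reference, a minimal usage sketch of the updated create_depth_message (keyword names taken from the docstring above; the array shape is arbitrary and chosen only to exercise the new batched-input path):

import numpy as np

from depthai_nodes.ml.messages.creators.depth import create_depth_message

# A 1xHxW (NHW) relative depth map is now accepted and squeezed to HW internally.
depth = np.random.rand(1, 480, 640).astype(np.float32)
img_frame = create_depth_message(depth_map=depth, depth_type="relative")
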
85 changes: 45 additions & 40 deletions depthai_nodes/ml/messages/creators/keypoints.py
@@ -95,61 +95,64 @@ def create_keypoints_message(
@param keypoints: Detected 2D or 3D keypoints of shape (N,2 or 3) meaning [...,[x, y],...] or [...,[x, y, z],...].
@type keypoints: np.ndarray or List[List[float]]
@param scores: Confidence scores of the detected keypoints.
@type scores: np.ndarray or List[float]
@type scores: Optional[np.ndarray or List[float]]
@param confidence_threshold: Confidence threshold of keypoint detections.
@type confidence_threshold: float
@type confidence_threshold: Optional[float]

@return: Keypoints message containing the detected keypoints.
@rtype: Keypoints

@raise ValueError: If the keypoints are not a numpy array or list.
@raise ValueError: If the keypoints are not of shape (N,2 or 3).
@raise ValueError: If the keypoints 2nd dimension is not of size E{2} or E{3}.
@raise ValueError: If the scores are not a numpy array or list.
@raise ValueError: If the scores are not of shape (N,).
@raise ValueError: If the keypoints and scores do not have the same length.
@raise ValueError: If scores and keypoints do not have the same length.
@raise ValueError: If score values are not floats.
@raise ValueError: If score values are not between 0 and 1.
@raise ValueError: If the confidence threshold is not a float.
@raise ValueError: If the confidence threshold is not provided when scores are provided.
@raise ValueError: If the confidence threshold is not between 0 and 1.
@raise ValueError: If the keypoints are not of shape (N,2 or 3).
@raise ValueError: If the keypoints 2nd dimension is not of size E{2} or E{3}.
"""

if not isinstance(keypoints, (np.ndarray, list)):
raise ValueError(
f"Keypoints should be numpy array or list, got {type(keypoints)}."
)

if not isinstance(scores, (np.ndarray, list)):
raise ValueError(f"Scores should be numpy array or list, got {type(scores)}.")

if len(keypoints) == 0:
raise ValueError("Keypoints should not be empty.")
if scores is not None:
if not isinstance(scores, (np.ndarray, list)):
raise ValueError(
f"Scores should be numpy array or list, got {type(scores)}."
)

if len(keypoints) != len(scores):
raise ValueError(
"Keypoints and scores should have the same length. Got {} keypoints and {} scores.".format(
len(keypoints), len(scores)
if len(keypoints) != len(scores):
raise ValueError(
"Keypoints and scores should have the same length. Got {} keypoints and {} scores.".format(
len(keypoints), len(scores)
)
)
)

if not all(isinstance(score, (float, np.floating)) for score in scores):
raise ValueError("Scores should only contain float values.")
if not all(0 <= score <= 1 for score in scores):
raise ValueError("Scores should only contain values between 0 and 1.")
if not all(isinstance(score, (float, np.floating)) for score in scores):
raise ValueError("Scores should only contain float values.")
if not all(0 <= score <= 1 for score in scores):
raise ValueError("Scores should only contain values between 0 and 1.")

if not isinstance(confidence_threshold, float):
raise ValueError(
f"The confidence_threshold should be float, got {type(confidence_threshold)}."
)
if confidence_threshold is not None:
if not isinstance(confidence_threshold, float):
raise ValueError(
f"The confidence_threshold should be float, got {type(confidence_threshold)}."
)

if not (0 <= confidence_threshold <= 1):
raise ValueError(
f"The confidence_threshold should be between 0 and 1, got confidence_threshold {confidence_threshold}."
)
if not (0 <= confidence_threshold <= 1):
raise ValueError(
f"The confidence_threshold should be between 0 and 1, got confidence_threshold {confidence_threshold}."
)

dimension = len(keypoints[0])
if dimension != 2 and dimension != 3:
raise ValueError(
f"All keypoints should be of dimension 2 or 3, got dimension {dimension}."
)
if len(keypoints) != 0:
dimension = len(keypoints[0])
if dimension != 2 and dimension != 3:
raise ValueError(
f"All keypoints should be of dimension 2 or 3, got dimension {dimension}."
)

if isinstance(keypoints, list):
for keypoint in keypoints:
@@ -168,14 +171,16 @@ def create_keypoints_message(
)

keypoints = np.array(keypoints)
scores = np.array(scores)
if scores is not None:
scores = np.array(scores)

if len(keypoints.shape) != 2:
raise ValueError(
f"Keypoints should be of shape (N,2 or 3) got {keypoints.shape}."
)
if len(keypoints) != 0:
if len(keypoints.shape) != 2:
raise ValueError(
f"Keypoints should be of shape (N,2 or 3) got {keypoints.shape}."
)

use_3d = keypoints.shape[1] == 3
use_3d = keypoints.shape[1] == 3

keypoints_msg = Keypoints()
points = []
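
A minimal usage sketch reflecting the now-optional scores and confidence_threshold (keyword names taken from the docstring above, assuming no other required parameters):

import numpy as np

from depthai_nodes.ml.messages.creators.keypoints import create_keypoints_message

kpts = np.array([[0.25, 0.50], [0.75, 0.50]], dtype=np.float32)  # (N, 2) keypoints

# With confidences: the updated checks require a confidence_threshold whenever scores are given.
msg = create_keypoints_message(
    keypoints=kpts,
    scores=np.array([0.9, 0.4], dtype=np.float32),
    confidence_threshold=0.5,
)

# Without confidences: both arguments can now be omitted entirely.
msg_no_scores = create_keypoints_message(keypoints=kpts)
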
3 changes: 2 additions & 1 deletion examples/main.py
@@ -1,6 +1,7 @@
import depthai as dai
from utils.arguments import initialize_argparser, parse_model_slug
from utils.model import get_input_shape, get_model_from_hub, get_parser, setup_parser
from utils.model import get_input_shape, get_model_from_hub, get_parser
from utils.parser import setup_parser
from visualization.visualize import visualize

# Initialize the argument parser
106 changes: 1 addition & 105 deletions examples/utils/model.py
@@ -2,14 +2,7 @@

import depthai as dai

from depthai_nodes.ml.parsers import (
ClassificationParser,
KeypointParser,
MonocularDepthParser,
SCRFDParser,
SegmentationParser,
XFeatParser,
)
from depthai_nodes.ml.parsers import *


def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive:
@@ -92,100 +85,3 @@ def get_input_shape(nn_archive: dai.NNArchive) -> Tuple[int, int]:
print(f"Input shape: {input_shape}")

return input_shape


def setup_scrfd_parser(parser: SCRFDParser, params: dict):
"""Setup the SCRFD parser with the required metadata."""
try:
num_anchors = params["num_anchors"]
feat_stride_fpn = params["feat_stride_fpn"]
parser.setNumAnchors(num_anchors)
parser.setFeatStrideFPN(feat_stride_fpn)
except Exception:
print(
"This NN archive does not have required metadata for SCRFDParser. Skipping setup..."
)


def setup_segmentation_parser(parser: SegmentationParser, params: dict):
"""Setup the segmentation parser with the required metadata."""
try:
background_class = params["background_class"]
parser.setBackgroundClass(background_class)
except Exception:
print(
"This NN archive does not have required metadata for SegmentationParser. Skipping setup..."
)


def setup_keypoint_parser(parser: KeypointParser, params: dict):
"""Setup the keypoint parser with the required metadata."""
try:
num_keypoints = params["n_keypoints"]
scale_factor = params["scale_factor"]
parser.setNumKeypoints(num_keypoints)
parser.setScaleFactor(scale_factor)
except Exception:
print(
"This NN archive does not have required metadata for KeypointParser. Skipping setup..."
)


def setup_classification_parser(parser: ClassificationParser, params: dict):
"""Setup the classification parser with the required metadata."""
try:
classes = params["classes"]
is_softmax = params["is_softmax"]
parser.setClasses(classes)
parser.setSoftmax(is_softmax)
except Exception:
print(
"This NN archive does not have required metadata for ClassificationParser. Skipping setup..."
)


def setup_monocular_depth_parser(parser: MonocularDepthParser, params: dict):
"""Setup the monocular depth parser with the required metadata."""
try:
depth_type = params["depth_type"]
if depth_type == "relative":
parser.setRelativeDepthType()
else:
parser.setMetricDepthType()
except Exception:
print(
"This NN archive does not have required metadata for MonocularDepthParser. Skipping setup..."
)


def setup_xfeat_parser(parser: XFeatParser, params: dict):
"""Setup the XFeat parser with the required metadata."""
try:
input_size = params["input_size"]
parser.setInputSize(input_size)
parser.setOriginalSize(input_size)
except Exception:
print(
"This NN archive does not have required metadata for XFeatParser. Skipping setup..."
)


def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
"""Setup the parser with the NN archive."""

extraParams = (
nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
)

if parser_name == "SCRFDParser":
setup_scrfd_parser(parser, extraParams)
elif parser_name == "SegmentationParser":
setup_segmentation_parser(parser, extraParams)
elif parser_name == "KeypointParser":
setup_keypoint_parser(parser, extraParams)
elif parser_name == "ClassificationParser":
setup_classification_parser(parser, extraParams)
elif parser_name == "MonocularDepthParser":
setup_monocular_depth_parser(parser, extraParams)
elif parser_name == "XFeatParser":
setup_xfeat_parser(parser, extraParams)
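
For context, a rough sketch of how the helpers kept in this file combine with the relocated setup_parser (now imported from utils.parser, per the examples/main.py change above); the model slugs, the parser choice, and the pipeline wiring are assumptions for illustration only:

import depthai as dai

from depthai_nodes.ml.parsers import SCRFDParser
from utils.model import get_input_shape, get_model_from_hub
from utils.parser import setup_parser

# Hypothetical slugs; normally taken from the command-line arguments parsed in examples/main.py.
nn_archive = get_model_from_hub("luxonis/scrfd-face-detection", "10g-640x640")
input_shape = get_input_shape(nn_archive)

pipeline = dai.Pipeline()
parser = pipeline.create(SCRFDParser)  # parser class assumed; normally chosen via get_parser
setup_parser(parser, nn_archive, "SCRFDParser")
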