diff --git a/depthai_nodes/ml/messages/creators/depth.py b/depthai_nodes/ml/messages/creators/depth.py index 1b04b893..24c0f95c 100644 --- a/depthai_nodes/ml/messages/creators/depth.py +++ b/depthai_nodes/ml/messages/creators/depth.py @@ -11,7 +11,8 @@ def create_depth_message( ) -> dai.ImgFrame: """Create a DepthAI message for a depth map. - @param depth_map: A NumPy array representing the depth map with shape (HW). + @param depth_map: A NumPy array representing the depth map with shape HW or NHW/HWN. + Here N stands for batch dimension. @type depth_map: np.array @param depth_type: A string indicating the type of depth map. It can either be 'relative' or 'metric'. @@ -19,27 +20,27 @@ def create_depth_message( @return: An ImgFrame object containing the depth information. @rtype: dai.ImgFrame @raise ValueError: If the depth map is not a NumPy array. - @raise ValueError: If the depth map is not 2D. + @raise ValueError: If the depth map is not 2D or 3D. + @raise ValueError: If the depth map shape is not NHW or HWN. @raise ValueError: If the depth type is not 'relative' or 'metric'. + @raise NotImplementedError: If the depth type is 'metric'. """ if not isinstance(depth_map, np.ndarray): raise ValueError(f"Expected numpy array, got {type(depth_map)}.") - if len(depth_map.shape) != 3: - raise ValueError(f"Expected 3D input, got {len(depth_map.shape)}D input.") - - if depth_map.shape[0] == 1: - depth_map = depth_map[0, :, :] # CHW to HW - elif depth_map.shape[2] == 1: - depth_map = depth_map[:, :, 0] # HWC to HW - else: - raise ValueError( - f"Unexpected image shape. Expected CHW or HWC, got {depth_map.shape}." - ) + if len(depth_map.shape) == 3: + if depth_map.shape[0] == 1: + depth_map = depth_map[0, :, :] # CHW to HW + elif depth_map.shape[2] == 1: + depth_map = depth_map[:, :, 0] # HWC to HW + else: + raise ValueError( + f"Unexpected image shape. Expected NHW or HWN, got {depth_map.shape}." + ) if len(depth_map.shape) != 2: - raise ValueError(f"Expected 2D input, got {len(depth_map.shape)}D input.") + raise ValueError(f"Expected 2D or 3D input, got {len(depth_map.shape)}D input.") if depth_type == "relative": data_type = dai.ImgFrame.Type.RAW16 diff --git a/depthai_nodes/ml/messages/creators/keypoints.py b/depthai_nodes/ml/messages/creators/keypoints.py index 309f6cc9..c887f134 100644 --- a/depthai_nodes/ml/messages/creators/keypoints.py +++ b/depthai_nodes/ml/messages/creators/keypoints.py @@ -95,21 +95,22 @@ def create_keypoints_message( @param keypoints: Detected 2D or 3D keypoints of shape (N,2 or 3) meaning [...,[x, y],...] or [...,[x, y, z],...]. @type keypoints: np.ndarray or List[List[float]] @param scores: Confidence scores of the detected keypoints. - @type scores: np.ndarray or List[float] + @type scores: Optional[np.ndarray or List[float]] @param confidence_threshold: Confidence threshold of keypoint detections. - @type confidence_threshold: float + @type confidence_threshold: Optional[float] @return: Keypoints message containing the detected keypoints. @rtype: Keypoints @raise ValueError: If the keypoints are not a numpy array or list. - @raise ValueError: If the keypoints are not of shape (N,2 or 3). - @raise ValueError: If the keypoints 2nd dimension is not of size E{2} or E{3}. @raise ValueError: If the scores are not a numpy array or list. - @raise ValueError: If the scores are not of shape (N,). - @raise ValueError: If the keypoints and scores do not have the same length. + @raise ValueError: If scores and keypoints do not have the same length. + @raise ValueError: If score values are not floats. + @raise ValueError: If score values are not between 0 and 1. @raise ValueError: If the confidence threshold is not a float. - @raise ValueError: If the confidence threshold is not provided when scores are provided. + @raise ValueError: If the confidence threshold is not between 0 and 1. + @raise ValueError: If the keypoints are not of shape (N,2 or 3). + @raise ValueError: If the keypoints 2nd dimension is not of size E{2} or E{3}. """ if not isinstance(keypoints, (np.ndarray, list)): @@ -117,39 +118,41 @@ def create_keypoints_message( f"Keypoints should be numpy array or list, got {type(keypoints)}." ) - if not isinstance(scores, (np.ndarray, list)): - raise ValueError(f"Scores should be numpy array or list, got {type(scores)}.") - - if len(keypoints) == 0: - raise ValueError("Keypoints should not be empty.") + if scores is not None: + if not isinstance(scores, (np.ndarray, list)): + raise ValueError( + f"Scores should be numpy array or list, got {type(scores)}." + ) - if len(keypoints) != len(scores): - raise ValueError( - "Keypoints and scores should have the same length. Got {} keypoints and {} scores.".format( - len(keypoints), len(scores) + if len(keypoints) != len(scores): + raise ValueError( + "Keypoints and scores should have the same length. Got {} keypoints and {} scores.".format( + len(keypoints), len(scores) + ) ) - ) - if not all(isinstance(score, (float, np.floating)) for score in scores): - raise ValueError("Scores should only contain float values.") - if not all(0 <= score <= 1 for score in scores): - raise ValueError("Scores should only contain values between 0 and 1.") + if not all(isinstance(score, (float, np.floating)) for score in scores): + raise ValueError("Scores should only contain float values.") + if not all(0 <= score <= 1 for score in scores): + raise ValueError("Scores should only contain values between 0 and 1.") - if not isinstance(confidence_threshold, float): - raise ValueError( - f"The confidence_threshold should be float, got {type(confidence_threshold)}." - ) + if confidence_threshold is not None: + if not isinstance(confidence_threshold, float): + raise ValueError( + f"The confidence_threshold should be float, got {type(confidence_threshold)}." + ) - if not (0 <= confidence_threshold <= 1): - raise ValueError( - f"The confidence_threshold should be between 0 and 1, got confidence_threshold {confidence_threshold}." - ) + if not (0 <= confidence_threshold <= 1): + raise ValueError( + f"The confidence_threshold should be between 0 and 1, got confidence_threshold {confidence_threshold}." + ) - dimension = len(keypoints[0]) - if dimension != 2 and dimension != 3: - raise ValueError( - f"All keypoints should be of dimension 2 or 3, got dimension {dimension}." - ) + if len(keypoints) != 0: + dimension = len(keypoints[0]) + if dimension != 2 and dimension != 3: + raise ValueError( + f"All keypoints should be of dimension 2 or 3, got dimension {dimension}." + ) if isinstance(keypoints, list): for keypoint in keypoints: @@ -168,14 +171,16 @@ def create_keypoints_message( ) keypoints = np.array(keypoints) - scores = np.array(scores) + if scores is not None: + scores = np.array(scores) - if len(keypoints.shape) != 2: - raise ValueError( - f"Keypoints should be of shape (N,2 or 3) got {keypoints.shape}." - ) + if len(keypoints) != 0: + if len(keypoints.shape) != 2: + raise ValueError( + f"Keypoints should be of shape (N,2 or 3) got {keypoints.shape}." + ) - use_3d = keypoints.shape[1] == 3 + use_3d = keypoints.shape[1] == 3 keypoints_msg = Keypoints() points = [] diff --git a/examples/main.py b/examples/main.py index 4428e6c9..99543219 100644 --- a/examples/main.py +++ b/examples/main.py @@ -1,6 +1,7 @@ import depthai as dai from utils.arguments import initialize_argparser, parse_model_slug -from utils.model import get_input_shape, get_model_from_hub, get_parser, setup_parser +from utils.model import get_input_shape, get_model_from_hub, get_parser +from utils.parser import setup_parser from visualization.visualize import visualize # Initialize the argument parser diff --git a/examples/utils/model.py b/examples/utils/model.py index 67a7863c..6881a54c 100644 --- a/examples/utils/model.py +++ b/examples/utils/model.py @@ -2,14 +2,7 @@ import depthai as dai -from depthai_nodes.ml.parsers import ( - ClassificationParser, - KeypointParser, - MonocularDepthParser, - SCRFDParser, - SegmentationParser, - XFeatParser, -) +from depthai_nodes.ml.parsers import * def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive: @@ -92,100 +85,3 @@ def get_input_shape(nn_archive: dai.NNArchive) -> Tuple[int, int]: print(f"Input shape: {input_shape}") return input_shape - - -def setup_scrfd_parser(parser: SCRFDParser, params: dict): - """Setup the SCRFD parser with the required metadata.""" - try: - num_anchors = params["num_anchors"] - feat_stride_fpn = params["feat_stride_fpn"] - parser.setNumAnchors(num_anchors) - parser.setFeatStrideFPN(feat_stride_fpn) - except Exception: - print( - "This NN archive does not have required metadata for SCRFDParser. Skipping setup..." - ) - - -def setup_segmentation_parser(parser: SegmentationParser, params: dict): - """Setup the segmentation parser with the required metadata.""" - try: - background_class = params["background_class"] - parser.setBackgroundClass(background_class) - except Exception: - print( - "This NN archive does not have required metadata for SegmentationParser. Skipping setup..." - ) - - -def setup_keypoint_parser(parser: KeypointParser, params: dict): - """Setup the keypoint parser with the required metadata.""" - try: - num_keypoints = params["n_keypoints"] - scale_factor = params["scale_factor"] - parser.setNumKeypoints(num_keypoints) - parser.setScaleFactor(scale_factor) - except Exception: - print( - "This NN archive does not have required metadata for KeypointParser. Skipping setup..." - ) - - -def setup_classification_parser(parser: ClassificationParser, params: dict): - """Setup the classification parser with the required metadata.""" - try: - classes = params["classes"] - is_softmax = params["is_softmax"] - parser.setClasses(classes) - parser.setSoftmax(is_softmax) - except Exception: - print( - "This NN archive does not have required metadata for ClassificationParser. Skipping setup..." - ) - - -def setup_monocular_depth_parser(parser: MonocularDepthParser, params: dict): - """Setup the monocular depth parser with the required metadata.""" - try: - depth_type = params["depth_type"] - if depth_type == "relative": - parser.setRelativeDepthType() - else: - parser.setMetricDepthType() - except Exception: - print( - "This NN archive does not have required metadata for MonocularDepthParser. Skipping setup..." - ) - - -def setup_xfeat_parser(parser: XFeatParser, params: dict): - """Setup the XFeat parser with the required metadata.""" - try: - input_size = params["input_size"] - parser.setInputSize(input_size) - parser.setOriginalSize(input_size) - except Exception: - print( - "This NN archive does not have required metadata for XFeatParser. Skipping setup..." - ) - - -def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str): - """Setup the parser with the NN archive.""" - - extraParams = ( - nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams - ) - - if parser_name == "SCRFDParser": - setup_scrfd_parser(parser, extraParams) - elif parser_name == "SegmentationParser": - setup_segmentation_parser(parser, extraParams) - elif parser_name == "KeypointParser": - setup_keypoint_parser(parser, extraParams) - elif parser_name == "ClassificationParser": - setup_classification_parser(parser, extraParams) - elif parser_name == "MonocularDepthParser": - setup_monocular_depth_parser(parser, extraParams) - elif parser_name == "XFeatParser": - setup_xfeat_parser(parser, extraParams) diff --git a/examples/utils/parser.py b/examples/utils/parser.py new file mode 100644 index 00000000..939f45d7 --- /dev/null +++ b/examples/utils/parser.py @@ -0,0 +1,107 @@ +import depthai as dai + +from depthai_nodes.ml.parsers import ( + ClassificationParser, + KeypointParser, + MonocularDepthParser, + SCRFDParser, + SegmentationParser, + XFeatParser, +) + + +def setup_scrfd_parser(parser: SCRFDParser, params: dict): + """Setup the SCRFD parser with the required metadata.""" + try: + num_anchors = params["num_anchors"] + feat_stride_fpn = params["feat_stride_fpn"] + parser.setNumAnchors(num_anchors) + parser.setFeatStrideFPN(feat_stride_fpn) + except Exception: + print( + "This NN archive does not have required metadata for SCRFDParser. Skipping setup..." + ) + + +def setup_segmentation_parser(parser: SegmentationParser, params: dict): + """Setup the segmentation parser with the required metadata.""" + try: + background_class = params["background_class"] + parser.setBackgroundClass(background_class) + except Exception: + print( + "This NN archive does not have required metadata for SegmentationParser. Skipping setup..." + ) + + +def setup_keypoint_parser(parser: KeypointParser, params: dict): + """Setup the keypoint parser with the required metadata.""" + try: + num_keypoints = params["n_keypoints"] + scale_factor = params["scale_factor"] + parser.setNumKeypoints(num_keypoints) + parser.setScaleFactor(scale_factor) + except Exception: + print( + "This NN archive does not have required metadata for KeypointParser. Skipping setup..." + ) + + +def setup_classification_parser(parser: ClassificationParser, params: dict): + """Setup the classification parser with the required metadata.""" + try: + classes = params["classes"] + is_softmax = params["is_softmax"] + parser.setClasses(classes) + parser.setSoftmax(is_softmax) + except Exception: + print( + "This NN archive does not have required metadata for ClassificationParser. Skipping setup..." + ) + + +def setup_monocular_depth_parser(parser: MonocularDepthParser, params: dict): + """Setup the monocular depth parser with the required metadata.""" + try: + depth_type = params["depth_type"] + if depth_type == "relative": + parser.setRelativeDepthType() + else: + parser.setMetricDepthType() + except Exception: + print( + "This NN archive does not have required metadata for MonocularDepthParser. Skipping setup..." + ) + + +def setup_xfeat_parser(parser: XFeatParser, params: dict): + """Setup the XFeat parser with the required metadata.""" + try: + input_size = params["input_size"] + parser.setInputSize(input_size) + parser.setOriginalSize(input_size) + except Exception: + print( + "This NN archive does not have required metadata for XFeatParser. Skipping setup..." + ) + + +def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str): + """Setup the parser with the NN archive.""" + + extraParams = ( + nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams + ) + + if parser_name == "SCRFDParser": + setup_scrfd_parser(parser, extraParams) + elif parser_name == "SegmentationParser": + setup_segmentation_parser(parser, extraParams) + elif parser_name == "KeypointParser": + setup_keypoint_parser(parser, extraParams) + elif parser_name == "ClassificationParser": + setup_classification_parser(parser, extraParams) + elif parser_name == "MonocularDepthParser": + setup_monocular_depth_parser(parser, extraParams) + elif parser_name == "XFeatParser": + setup_xfeat_parser(parser, extraParams) diff --git a/tests/unittests/test_creators/test_depth.py b/tests/unittests/test_creators/test_depth.py index 4a08d9e6..6e249a8b 100644 --- a/tests/unittests/test_creators/test_depth.py +++ b/tests/unittests/test_creators/test_depth.py @@ -21,14 +21,14 @@ def test_wrong_literal_type(): def test_not_3d_input(): - with pytest.raises(ValueError, match="Expected 3D input, got 1D input."): + with pytest.raises(ValueError, match="Expected 2D or 3D input, got 1D input."): create_depth_message(np.array([1, 2, 3]), "relative") def test_wrong_input_shape(): with pytest.raises( ValueError, - match=re.escape("Unexpected image shape. Expected CHW or HWC, got (3, 1, 3)."), + match=re.escape("Unexpected image shape. Expected NHW or HWN, got (3, 1, 3)."), ): create_depth_message( np.array([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]), "relative" diff --git a/tests/unittests/test_creators/test_keypoints.py b/tests/unittests/test_creators/test_keypoints.py index 4472db45..47dfc151 100644 --- a/tests/unittests/test_creators/test_keypoints.py +++ b/tests/unittests/test_creators/test_keypoints.py @@ -31,16 +31,11 @@ def test_none_keypoints(): def test_none_scores(): - with pytest.raises( - ValueError, - match="Scores should be numpy array or list, got .", - ): - create_keypoints_message(np.array([[1.0, 2.0, 3.0]]), None, 0.8) + create_keypoints_message(np.array([[1.0, 2.0, 3.0]]), None, None) def test_empty_keypoints(): - with pytest.raises(ValueError, match="Keypoints should not be empty."): - create_keypoints_message([], [], 0.8) + create_keypoints_message([], [], 0.8) def test_keypoints_and_scores_length():