Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adjust ParsingNeuralNetwork to build model directly from the model slug #134

Merged
merged 11 commits into from
Nov 21, 2024
47 changes: 19 additions & 28 deletions depthai_nodes/parsing_neural_network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Union, overload
from typing import Dict, Union

import depthai as dai

Expand Down Expand Up @@ -39,22 +39,10 @@ def __init__(self, *args, **kwargs) -> None:
self._nn = self._pipeline.create(dai.node.NeuralNetwork)
self._parsers: Dict[int, BaseParser] = {}

@overload
def build(
self, input: dai.Node.Output, nn_source: dai.NNModelDescription, fps: int
) -> "ParsingNeuralNetwork":
...

@overload
def build(
self, input: dai.Node.Output, nn_source: dai.NNArchive, fps: int
) -> "ParsingNeuralNetwork":
...

def build(
self,
input: dai.Node.Output,
nn_source: Union[dai.NNModelDescription, dai.NNArchive],
slug: str,
fps: int = None,
) -> "ParsingNeuralNetwork":
"""Builds the underlying NeuralNetwork node and creates parser nodes for each
Expand All @@ -63,9 +51,8 @@ def build(
@param input: Node's input. It is a linking point to which the NeuralNetwork is
linked. It accepts the output of a Camera node.
@type input: Node.Input
@param nn_source: NNModelDescription object containing the HubAI model
descriptors, or NNArchive object of the model.
@type nn_source: Union[dai.NNModelDescription, dai.NNArchive]
@param slug: HubAI model slug.
@type slug: str
@param fps: FPS limit for the model runtime.
@type fps: int
@return: Returns the ParsingNeuralNetwork object.
Expand All @@ -74,21 +61,25 @@ def build(
object.
"""

if isinstance(nn_source, dai.NNModelDescription):
if not nn_source.platform:
nn_source.platform = (
self.getParentPipeline().getDefaultDevice().getPlatformAsString()
if isinstance(slug, str):
if ":" not in slug:
raise ValueError(
"Slug must be in the format <model_slug>:<model_version_slug>."
)
self._nn_archive = dai.NNArchive(dai.getModelFromZoo(nn_source))
elif isinstance(nn_source, dai.NNArchive):
self._nn_archive = nn_source
else:
raise ValueError(
"nn_source must be either a NNModelDescription or NNArchive"
model_slug, model_version_slug = slug.split(":")
model_description = dai.NNModelDescription(
modelSlug=model_slug,
modelVersionSlug=model_version_slug,
)
model_description.platform = (
self.getParentPipeline().getDefaultDevice().getPlatformAsString()
)
self._nn_archive = dai.NNArchive(dai.getModelFromZoo(model_description))
else:
raise ValueError("Slug must be a string.")

kwargs = {"fps": fps} if fps else {}
self._nn.build(input, nn_source, **kwargs)
self._nn.build(input, self._nn_archive, **kwargs)

self._updateParsers(self._nn_archive)
return self
Expand Down
Loading