Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adjust ParsingNeuralNetwork to build model directly from the model slug #134

Merged
merged 11 commits into from
Nov 21, 2024
28 changes: 18 additions & 10 deletions depthai_nodes/parsing_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,16 @@ def build(
) -> "ParsingNeuralNetwork":
...

@overload
def build(
self, input: dai.Node.Output, nn_source: str, fps: int
) -> "ParsingNeuralNetwork":
...

def build(
self,
input: dai.Node.Output,
nn_source: Union[dai.NNModelDescription, dai.NNArchive],
nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
fps: int = None,
) -> "ParsingNeuralNetwork":
"""Builds the underlying NeuralNetwork node and creates parser nodes for each
Expand All @@ -63,9 +69,9 @@ def build(
@param input: Node's input. It is a linking point to which the NeuralNetwork is
linked. It accepts the output of a Camera node.
@type input: Node.Input
@param nn_source: NNModelDescription object containing the HubAI model
descriptors, or NNArchive object of the model.
@type nn_source: Union[dai.NNModelDescription, dai.NNArchive]

@param nn_source: NNModelDescription object containing the HubAI model descriptors, NNArchive object of the model, or HubAI model slug in form of <model_slug>:<model_version_slug> or <model_slug>:<model_version_slug>:<model_instance_hash>.
@type nn_source: Union[dai.NNModelDescription, dai.NNArchive, str]
@param fps: FPS limit for the model runtime.
@type fps: int
@return: Returns the ParsingNeuralNetwork object.
Expand All @@ -74,21 +80,23 @@ def build(
object.
"""

if isinstance(nn_source, dai.NNModelDescription):
platform = self.getParentPipeline().getDefaultDevice().getPlatformAsString()

if isinstance(nn_source, str):
nn_source = dai.NNModelDescription(nn_source)
if isinstance(nn_source, (dai.NNModelDescription, str)):
if not nn_source.platform:
nn_source.platform = (
self.getParentPipeline().getDefaultDevice().getPlatformAsString()
)
nn_source.platform = platform
self._nn_archive = dai.NNArchive(dai.getModelFromZoo(nn_source))
elif isinstance(nn_source, dai.NNArchive):
self._nn_archive = nn_source
else:
raise ValueError(
"nn_source must be either a NNModelDescription or NNArchive"
"nn_source must be either a NNModelDescription, NNArchive, or a string representing HubAI model slug."
)

kwargs = {"fps": fps} if fps else {}
self._nn.build(input, nn_source, **kwargs)
self._nn.build(input, self._nn_archive, **kwargs)

self._updateParsers(self._nn_archive)
return self
Expand Down
6 changes: 3 additions & 3 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ Prepare the model slug and run:

```
cd examples
python main.py -s <slug_of_your_model>
python main.py -m <slug_of_your_model>
```

For example:

```
python main.py -s yolov6-nano:coco-416x416
python main.py -m yolov6-nano:coco-416x416
```

Note that for running the examples you need an RVC2 device connected.
Expand All @@ -29,7 +29,7 @@ If using OAK-D Lite, make sure to also set the FPS limit under 28.5.
For example:

```
python main.py -s yolov6-nano:coco-416x416 -fps 28
python main.py -m yolov6-nano:coco-416x416 -fps 28
```

Some models have small input sizes, and requesting a small image size from `Camera` is problematic, so we request a 4x bigger frame and resize it back down. During visualization the image frame is resized back up, so some image quality is lost — this affects visualization only.
Expand Down
88 changes: 31 additions & 57 deletions examples/main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import depthai as dai
from utils.arguments import initialize_argparser, parse_fps_limit, parse_model_slug
from utils.model import get_input_shape, get_model_from_hub, get_parser
from utils.arguments import initialize_argparser
from utils.model import get_input_shape, get_nn_archive_from_hub, get_parser
from utils.xfeat import xfeat_mono, xfeat_stereo
from visualization.visualize import visualize

from depthai_nodes.parser_generator import ParserGenerator
from depthai_nodes import ParsingNeuralNetwork

# Initialize the argument parser
arg_parser, args = initialize_argparser()

# Parse the model slug
model_slug, model_version_slug = parse_model_slug(args)
fps_limit = parse_fps_limit(args)

# Get the model from the HubAI
nn_archive = get_model_from_hub(model_slug, model_version_slug)
# Parse the arguments
model = args.model
fps_limit = args.fps_limit

# Get the parser
parser_class, parser_name = get_parser(nn_archive)
nn_archive = get_nn_archive_from_hub(model)
_, parser_name = get_parser(nn_archive)
input_shape = get_input_shape(nn_archive)

if parser_name == "XFeatMonoParser":
Expand All @@ -31,58 +29,34 @@
with dai.Pipeline() as pipeline:
cam = pipeline.create(dai.node.Camera).build()

# YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser
if parser_name == "YOLO" or parser_name == "SSD":
network = pipeline.create(dai.node.DetectionNetwork).build(
cam.requestOutput(
input_shape, type=dai.ImgFrame.Type.BGR888p, fps=fps_limit
),
nn_archive,
)
parser_queue = network.out.createOutputQueue()
if "gray" in model:
image_type = dai.ImgFrame.Type.GRAY8
else:
image_type = dai.ImgFrame.Type.BGR888p
if "gray" in model_version_slug:
image_type = dai.ImgFrame.Type.GRAY8

if input_shape[0] < 128 or input_shape[1] < 128:
print(
"Input shape is too small so we are requesting a larger image and resizing it."
)
print(
"During visualization we resize small image back to large, so image quality is lower."
)
manip = pipeline.create(dai.node.ImageManip)
manip.initialConfig.setResize(input_shape)
large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
cam.requestOutput(large_input_shape, type=image_type, fps=fps_limit).link(
manip.inputImage
)
network = pipeline.create(dai.node.NeuralNetwork).build(
manip.out, nn_archive
)
else:
network = pipeline.create(dai.node.NeuralNetwork).build(
cam.requestOutput(input_shape, type=image_type, fps=fps_limit),
nn_archive,
)

parsers = pipeline.create(ParserGenerator).build(nn_archive)

if len(parsers) == 0:
raise ValueError("No parsers were generated.")

if len(parsers) > 1:
raise ValueError("Only models with one parser are supported.")

parser = parsers[0]

# Linking
network.out.link(parser.input)
if input_shape[0] < 128 or input_shape[1] < 128:
print(
"Input shape is too small so we are requesting a larger image and resizing it."
)
print(
"During visualization we resize small image back to large, so image quality is lower."
)
manip = pipeline.create(dai.node.ImageManip)
manip.initialConfig.setResize(input_shape)
large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
cam.requestOutput(large_input_shape, type=image_type, fps=fps_limit).link(
manip.inputImage
)
nn = pipeline.create(ParsingNeuralNetwork).build(manip.out, model)
else:
nn = pipeline.create(ParsingNeuralNetwork).build(
cam.requestOutput(input_shape, type=image_type, fps=fps_limit),
model,
)

parser_queue = parser.out.createOutputQueue()
parser_queue = nn.out.createOutputQueue()

camera_queue = network.passthrough.createOutputQueue()
camera_queue = nn.passthrough.createOutputQueue()

pipeline.start()

Expand Down
37 changes: 3 additions & 34 deletions examples/utils/arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
from typing import Tuple


def initialize_argparser():
Expand All @@ -11,9 +10,9 @@ def initialize_argparser():
Currently, only RVC2 devices are supported. If using OAK-D Lite, please set the FPS limit to 28."

parser.add_argument(
"-s",
"--model_slug",
help="slug of the model in HubAI.",
"-m",
"--model",
help="HubAI model slug.",
required=True,
type=str,
)
Expand All @@ -30,33 +29,3 @@ def initialize_argparser():
args = parser.parse_args()

return parser, args


def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]:
    """Extract the model slug and model version slug from parsed CLI arguments.

    Expects ``args.model_slug`` in the form ``model_slug:model_version_slug``;
    raises if the separator is missing.
    """
    raw_slug = args.model_slug

    # The two components are mandatory, so a missing separator is an error.
    if ":" not in raw_slug:
        raise NameError(
            "Please provide the model slug in the format of 'model_slug:model_version_slug'"
        )

    # Keep only the first two colon-separated components, mirroring the
    # original parts[0]/parts[1] access (any trailing components are ignored).
    name, version = raw_slug.split(":")[:2]

    return name, version


def parse_fps_limit(args: argparse.Namespace) -> int:
    """Return the FPS limit taken from the parsed CLI arguments."""
    return args.fps_limit
13 changes: 4 additions & 9 deletions examples/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,13 @@
from depthai_nodes.ml.parsers import *


def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive:
"""Get the model from the HubAI and return the NN archive."""
print(
f"Downloading model {model_slug} with version {model_version_slug} from HubAI..."
)
modelDescription = dai.NNModelDescription(
modelSlug=model_slug, modelVersionSlug=model_version_slug, platform="RVC2"
)
def get_nn_archive_from_hub(model: str) -> dai.NNArchive:
    """Download a model from HubAI by its slug and wrap it in an NN archive.

    The model is requested for the RVC2 platform (hard-coded below).
    """
    print(f"Downloading model {model} from HubAI...")
    description = dai.NNModelDescription(model=model, platform="RVC2")
    archive_path = dai.getModelFromZoo(description)
    print("Download successful!")
    return dai.NNArchive(archive_path)


Expand Down
Loading