Skip to content

Commit

Permalink
feat: adjust ParsingNeuralNetwork to build model directly from the model slug (#134)
Browse files Browse the repository at this point in the history

* feat: adjust ParsingNeuralNetwork.build() to updates in DAI

* fix: pre-commit

* fix: slug splitting if instance hash is present

* fix: remove slug splitting

* feat: adjust examples

* fix: rename slug to model

* fix: pre-commit

* fix: add previous options to build ParsingNeuralNetwork from NNArchive or ModelDescription

* fix: pre-commit

* feat: remove usage of DetectionNetwork for YOLO and SSD examples

* fix: pre-commit
  • Loading branch information
jkbmrz authored Nov 21, 2024
1 parent 3e6f3d1 commit 0ccc611
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 113 deletions.
28 changes: 18 additions & 10 deletions depthai_nodes/parsing_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,16 @@ def build(
) -> "ParsingNeuralNetwork":
...

@overload
def build(
self, input: dai.Node.Output, nn_source: str, fps: int
) -> "ParsingNeuralNetwork":
...

def build(
self,
input: dai.Node.Output,
nn_source: Union[dai.NNModelDescription, dai.NNArchive],
nn_source: Union[dai.NNModelDescription, dai.NNArchive, str],
fps: int = None,
) -> "ParsingNeuralNetwork":
"""Builds the underlying NeuralNetwork node and creates parser nodes for each
Expand All @@ -63,9 +69,9 @@ def build(
@param input: Node's input. It is a linking point to which the NeuralNetwork is
linked. It accepts the output of a Camera node.
@type input: Node.Input
@param nn_source: NNModelDescription object containing the HubAI model
descriptors, or NNArchive object of the model.
@type nn_source: Union[dai.NNModelDescription, dai.NNArchive]
@param nn_source: NNModelDescription object containing the HubAI model descriptors, NNArchive object of the model, or HubAI model slug in form of <model_slug>:<model_version_slug> or <model_slug>:<model_version_slug>:<model_instance_hash>.
@type nn_source: Union[dai.NNModelDescription, dai.NNArchive, str]
@param fps: FPS limit for the model runtime.
@type fps: int
@return: Returns the ParsingNeuralNetwork object.
Expand All @@ -74,21 +80,23 @@ def build(
object.
"""

if isinstance(nn_source, dai.NNModelDescription):
platform = self.getParentPipeline().getDefaultDevice().getPlatformAsString()

if isinstance(nn_source, str):
nn_source = dai.NNModelDescription(nn_source)
if isinstance(nn_source, (dai.NNModelDescription, str)):
if not nn_source.platform:
nn_source.platform = (
self.getParentPipeline().getDefaultDevice().getPlatformAsString()
)
nn_source.platform = platform
self._nn_archive = dai.NNArchive(dai.getModelFromZoo(nn_source))
elif isinstance(nn_source, dai.NNArchive):
self._nn_archive = nn_source
else:
raise ValueError(
"nn_source must be either a NNModelDescription or NNArchive"
"nn_source must be either a NNModelDescription, NNArchive, or a string representing HubAI model slug."
)

kwargs = {"fps": fps} if fps else {}
self._nn.build(input, nn_source, **kwargs)
self._nn.build(input, self._nn_archive, **kwargs)

self._updateParsers(self._nn_archive)
return self
Expand Down
6 changes: 3 additions & 3 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ Prepare the model slug and run:

```
cd examples
python main.py -s <slug_of_your_model>
python main.py -m <slug_of_your_model>
```

For example:

```
python main.py -s yolov6-nano:coco-416x416
python main.py -m yolov6-nano:coco-416x416
```

Note that for running the examples you need an RVC2 device connected.
Expand All @@ -29,7 +29,7 @@ If using OAK-D Lite, make sure to also set the FPS limit under 28.5.
For example:

```
python main.py -s yolov6-nano:coco-416x416 -fps 28
python main.py -m yolov6-nano:coco-416x416 -fps 28
```

Some models have small input sizes, and requesting a small image size from `Camera` is problematic, so we request a 4x bigger frame and resize it back down. During visualization the image frame is resized back up, so some image quality is lost — but only for visualization.
Expand Down
88 changes: 31 additions & 57 deletions examples/main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import depthai as dai
from utils.arguments import initialize_argparser, parse_fps_limit, parse_model_slug
from utils.model import get_input_shape, get_model_from_hub, get_parser
from utils.arguments import initialize_argparser
from utils.model import get_input_shape, get_nn_archive_from_hub, get_parser
from utils.xfeat import xfeat_mono, xfeat_stereo
from visualization.visualize import visualize

from depthai_nodes.parser_generator import ParserGenerator
from depthai_nodes import ParsingNeuralNetwork

# Initialize the argument parser
arg_parser, args = initialize_argparser()

# Parse the model slug
model_slug, model_version_slug = parse_model_slug(args)
fps_limit = parse_fps_limit(args)

# Get the model from the HubAI
nn_archive = get_model_from_hub(model_slug, model_version_slug)
# Parse the arguments
model = args.model
fps_limit = args.fps_limit

# Get the parser
parser_class, parser_name = get_parser(nn_archive)
nn_archive = get_nn_archive_from_hub(model)
_, parser_name = get_parser(nn_archive)
input_shape = get_input_shape(nn_archive)

if parser_name == "XFeatMonoParser":
Expand All @@ -31,58 +29,34 @@
with dai.Pipeline() as pipeline:
cam = pipeline.create(dai.node.Camera).build()

# YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser
if parser_name == "YOLO" or parser_name == "SSD":
network = pipeline.create(dai.node.DetectionNetwork).build(
cam.requestOutput(
input_shape, type=dai.ImgFrame.Type.BGR888p, fps=fps_limit
),
nn_archive,
)
parser_queue = network.out.createOutputQueue()
if "gray" in model:
image_type = dai.ImgFrame.Type.GRAY8
else:
image_type = dai.ImgFrame.Type.BGR888p
if "gray" in model_version_slug:
image_type = dai.ImgFrame.Type.GRAY8

if input_shape[0] < 128 or input_shape[1] < 128:
print(
"Input shape is too small so we are requesting a larger image and resizing it."
)
print(
"During visualization we resize small image back to large, so image quality is lower."
)
manip = pipeline.create(dai.node.ImageManip)
manip.initialConfig.setResize(input_shape)
large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
cam.requestOutput(large_input_shape, type=image_type, fps=fps_limit).link(
manip.inputImage
)
network = pipeline.create(dai.node.NeuralNetwork).build(
manip.out, nn_archive
)
else:
network = pipeline.create(dai.node.NeuralNetwork).build(
cam.requestOutput(input_shape, type=image_type, fps=fps_limit),
nn_archive,
)

parsers = pipeline.create(ParserGenerator).build(nn_archive)

if len(parsers) == 0:
raise ValueError("No parsers were generated.")

if len(parsers) > 1:
raise ValueError("Only models with one parser are supported.")

parser = parsers[0]

# Linking
network.out.link(parser.input)
if input_shape[0] < 128 or input_shape[1] < 128:
print(
"Input shape is too small so we are requesting a larger image and resizing it."
)
print(
"During visualization we resize small image back to large, so image quality is lower."
)
manip = pipeline.create(dai.node.ImageManip)
manip.initialConfig.setResize(input_shape)
large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
cam.requestOutput(large_input_shape, type=image_type, fps=fps_limit).link(
manip.inputImage
)
nn = pipeline.create(ParsingNeuralNetwork).build(manip.out, model)
else:
nn = pipeline.create(ParsingNeuralNetwork).build(
cam.requestOutput(input_shape, type=image_type, fps=fps_limit),
model,
)

parser_queue = parser.out.createOutputQueue()
parser_queue = nn.out.createOutputQueue()

camera_queue = network.passthrough.createOutputQueue()
camera_queue = nn.passthrough.createOutputQueue()

pipeline.start()

Expand Down
37 changes: 3 additions & 34 deletions examples/utils/arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
from typing import Tuple


def initialize_argparser():
Expand All @@ -11,9 +10,9 @@ def initialize_argparser():
Currently, only RVC2 devices are supported. If using OAK-D Lite, please set the FPS limit to 28."

parser.add_argument(
"-s",
"--model_slug",
help="slug of the model in HubAI.",
"-m",
"--model",
help="HubAI model slug.",
required=True,
type=str,
)
Expand All @@ -30,33 +29,3 @@ def initialize_argparser():
args = parser.parse_args()

return parser, args


def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]:
"""Parse the model slug from the arguments.
Returns the model slug and model version slug.
"""
model_slug = args.model_slug

# parse the model slug
if ":" not in model_slug:
raise NameError(
"Please provide the model slug in the format of 'model_slug:model_version_slug'"
)

model_slug_parts = model_slug.split(":")
model_slug = model_slug_parts[0]
model_version_slug = model_slug_parts[1]

return model_slug, model_version_slug


def parse_fps_limit(args: argparse.Namespace) -> int:
"""Parse the FPS limit from the arguments.
Returns the FPS limit.
"""
fps_limit = args.fps_limit

return fps_limit
13 changes: 4 additions & 9 deletions examples/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,13 @@
from depthai_nodes.ml.parsers import *


def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive:
"""Get the model from the HubAI and return the NN archive."""
print(
f"Downloading model {model_slug} with version {model_version_slug} from HubAI..."
)
modelDescription = dai.NNModelDescription(
modelSlug=model_slug, modelVersionSlug=model_version_slug, platform="RVC2"
)
def get_nn_archive_from_hub(model: str) -> dai.NNArchive:
"""Get NN archive from the HubAI based on model slug."""
print(f"Downloading model {model} from HubAI...")
modelDescription = dai.NNModelDescription(model=model, platform="RVC2")
archivePath = dai.getModelFromZoo(modelDescription)
print("Download successful!")
nn_archive = dai.NNArchive(archivePath)

return nn_archive


Expand Down

0 comments on commit 0ccc611

Please sign in to comment.