Commit

Examples added.
kkeroo committed Aug 30, 2024
1 parent afe4ea7 commit 1c353dd
Showing 15 changed files with 677 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -162,3 +162,5 @@ cython_debug/
#.idea/

.DS_Store

.depthai_cached_models/
77 changes: 77 additions & 0 deletions examples/main.py
@@ -0,0 +1,77 @@
import depthai as dai
from utils.arguments import initialize_argparser, parse_model_slug
from utils.model import get_input_shape, get_model_from_hub, get_parser, setup_parser
from visualization.visualize import visualize

# Initialize the argument parser
arg_parser, args = initialize_argparser()

# Parse the model slug
model_slug, model_version_slug = parse_model_slug(args)

# Get the model from HubAI
nn_archive = get_model_from_hub(model_slug, model_version_slug)

# Get the parser
parser_class, parser_name = get_parser(nn_archive)
input_shape = get_input_shape(nn_archive)

if parser_name == "XFeatParser":
    raise NotImplementedError("XFeatParser is not supported in this script yet.")

# Create the pipeline
with dai.Pipeline() as pipeline:
    cam = pipeline.create(dai.node.Camera).build()

    # YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser
    if parser_name == "YOLO" or parser_name == "SSD":
        network = pipeline.create(dai.node.DetectionNetwork).build(
            cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p), nn_archive
        )
        parser_queue = network.out.createOutputQueue()
    else:
        image_type = dai.ImgFrame.Type.BGR888p
        if "gray" in model_version_slug:
            image_type = dai.ImgFrame.Type.GRAY8

        if input_shape[0] < 128 or input_shape[1] < 128:
            print(
                "Input shape is too small, so we request a larger image and resize it."
            )
            print(
                "During visualization we resize the small image back up, so image quality is lower."
            )
            manip = pipeline.create(dai.node.ImageManip)
            manip.initialConfig.setResize(input_shape)
            large_input_shape = (input_shape[0] * 4, input_shape[1] * 4)
            cam.requestOutput(large_input_shape, type=image_type).link(manip.inputImage)
            network = pipeline.create(dai.node.NeuralNetwork).build(
                manip.out, nn_archive
            )
        else:
            network = pipeline.create(dai.node.NeuralNetwork).build(
                cam.requestOutput(input_shape, type=image_type), nn_archive
            )

        parser = pipeline.create(parser_class)
        setup_parser(parser, nn_archive, parser_name)

        # Linking
        network.out.link(parser.input)

        parser_queue = parser.out.createOutputQueue()

    camera_queue = network.passthrough.createOutputQueue()

    pipeline.start()

    while pipeline.isRunning():
        frame = camera_queue.get().getCvFrame()
        message = parser_queue.get()

        extraParams = (
            nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
        )
        if visualize(frame, message, parser_name, extraParams):
            pipeline.stop()
            break
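
Note: the example would be launched along these lines (the slug is a placeholder, not a real HubAI model):

    python main.py --model_slug <model_slug>:<model_version_slug>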
Empty file added examples/utils/__init__.py
43 changes: 43 additions & 0 deletions examples/utils/arguments.py
@@ -0,0 +1,43 @@
import argparse
from typing import Tuple


def initialize_argparser():
    """Initialize the argument parser for the script."""
    parser = argparse.ArgumentParser()
    parser.description = "General example script to run any model available in HubAI on a DepthAI device. \
All you need is the model slug; the script downloads the model from HubAI and creates \
the whole pipeline with visualizations. You also need a DepthAI device connected to your computer. \
Currently, only RVC2 devices are supported."

    parser.add_argument(
        "-s",
        "--model_slug",
        help="slug of the model in HubAI.",
        required=True,
        type=str,
    )

    args = parser.parse_args()

    return parser, args


def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]:
    """Parse the model slug from the arguments.

    Returns the model slug and model version slug.
    """
    model_slug = args.model_slug

    # parse the model slug
    if ":" not in model_slug:
        raise ValueError(
            "Please provide the model slug in the format 'model_slug:model_version_slug'"
        )

    model_slug_parts = model_slug.split(":")
    model_slug = model_slug_parts[0]
    model_version_slug = model_slug_parts[1]

    return model_slug, model_version_slug
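
For illustration, a quick sketch of the slug parsing above (the slug is made up; any 'name:version' pair works):

    import argparse

    # Hypothetical slug, for illustration only.
    args = argparse.Namespace(model_slug="some-model:some-version")
    model_slug, model_version_slug = parse_model_slug(args)
    # model_slug == "some-model", model_version_slug == "some-version"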
191 changes: 191 additions & 0 deletions examples/utils/model.py
@@ -0,0 +1,191 @@
import sys
from typing import List, Optional, Tuple

import depthai as dai

from depthai_nodes.ml.parsers import (
ClassificationParser,
KeypointParser,
MonocularDepthParser,
SCRFDParser,
SegmentationParser,
XFeatParser,
)


def get_model_from_hub(model_slug: str, model_version_slug: str) -> dai.NNArchive:
    """Get the model from HubAI and return the NN archive."""
    print(
        f"Downloading model {model_slug} with version {model_version_slug} from HubAI..."
    )
    modelDescription = dai.NNModelDescription(
        modelSlug=model_slug, modelVersionSlug=model_version_slug, platform="RVC2"
    )
    archivePath = dai.getModelFromZoo(modelDescription)
    print("Download successful!")
    nn_archive = dai.NNArchive(archivePath)

    return nn_archive


def get_parser_from_archive(nn_archive: dai.NNArchive) -> str:
    """Get the required parser from the NN archive."""
    try:
        required_parser = nn_archive.getConfig().getConfigV1().model.heads[0].parser
    except AttributeError:
        print(
            "This NN archive does not have a parser. Please use NN archives that have parsers."
        )
        sys.exit(1)

    print(f"Required parser: {required_parser}")

    return required_parser


def get_parser(nn_archive: dai.NNArchive) -> Tuple[Optional[dai.ThreadedNode], str]:
    """Map the parser from the NN archive to the actual parser in depthai-nodes."""
    required_parser = get_parser_from_archive(nn_archive)

    if required_parser == "YOLO" or required_parser == "SSD":
        return None, required_parser

    parser = globals().get(required_parser, None)

    if parser is None:
        raise NameError(
            f"Parser {required_parser} is not available in the depthai_nodes.ml.parsers module."
        )

    return parser, required_parser


def get_inputs_from_archive(nn_archive: dai.NNArchive) -> List:
    """Get all inputs from the NN archive."""
    try:
        inputs = nn_archive.getConfig().getConfigV1().model.inputs
    except AttributeError:
        print(
            "This NN archive does not have an input shape. Please use NN archives that have input shapes."
        )
        sys.exit(1)

    return inputs


def get_input_shape(nn_archive: dai.NNArchive) -> Tuple[int, int]:
    """Get the input shape of the model from the NN archive."""
    inputs = get_inputs_from_archive(nn_archive)

    if len(inputs) > 1:
        raise ValueError(
            "This model has more than one input. Currently, only models with one input are supported."
        )

    try:
        input_shape = inputs[0].shape[2:][::-1]
    except AttributeError:
        print(
            "This NN archive does not have an input shape. Please use NN archives that have input shapes."
        )
        sys.exit(1)

    print(f"Input shape: {input_shape}")

    return input_shape
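
For illustration, the slice above assumes NCHW tensor order, so shape[2:][::-1] turns (height, width) into (width, height). A minimal sketch with a hypothetical shape:

    # Hypothetical NCHW input shape, for illustration only.
    shape = [1, 3, 288, 512]       # N, C, H, W
    input_shape = shape[2:][::-1]  # [288, 512] reversed -> [512, 288]
    # i.e. (width, height) = (512, 288)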


def setup_scrfd_parser(parser: SCRFDParser, params: dict):
    """Set up the SCRFD parser with the required metadata."""
    try:
        num_anchors = params["num_anchors"]
        feat_stride_fpn = params["feat_stride_fpn"]
        parser.setNumAnchors(num_anchors)
        parser.setFeatStrideFPN(feat_stride_fpn)
    except Exception:
        print(
            "This NN archive does not have the required metadata for SCRFDParser. Skipping setup..."
        )


def setup_segmentation_parser(parser: SegmentationParser, params: dict):
    """Set up the segmentation parser with the required metadata."""
    try:
        background_class = params["background_class"]
        parser.setBackgroundClass(background_class)
    except Exception:
        print(
            "This NN archive does not have the required metadata for SegmentationParser. Skipping setup..."
        )


def setup_keypoint_parser(parser: KeypointParser, params: dict):
    """Set up the keypoint parser with the required metadata."""
    try:
        num_keypoints = params["n_keypoints"]
        scale_factor = params["scale_factor"]
        parser.setNumKeypoints(num_keypoints)
        parser.setScaleFactor(scale_factor)
    except Exception:
        print(
            "This NN archive does not have the required metadata for KeypointParser. Skipping setup..."
        )


def setup_classification_parser(parser: ClassificationParser, params: dict):
    """Set up the classification parser with the required metadata."""
    try:
        classes = params["classes"]
        is_softmax = params["is_softmax"]
        parser.setClasses(classes)
        parser.setSoftmax(is_softmax)
    except Exception:
        print(
            "This NN archive does not have the required metadata for ClassificationParser. Skipping setup..."
        )


def setup_monocular_depth_parser(parser: MonocularDepthParser, params: dict):
    """Set up the monocular depth parser with the required metadata."""
    try:
        depth_type = params["depth_type"]
        if depth_type == "relative":
            parser.setRelativeDepthType()
        else:
            parser.setMetricDepthType()
    except Exception:
        print(
            "This NN archive does not have the required metadata for MonocularDepthParser. Skipping setup..."
        )


def setup_xfeat_parser(parser: XFeatParser, params: dict):
    """Set up the XFeat parser with the required metadata."""
    try:
        input_size = params["input_size"]
        parser.setInputSize(input_size)
        parser.setOriginalSize(input_size)
    except Exception:
        print(
            "This NN archive does not have the required metadata for XFeatParser. Skipping setup..."
        )


def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
    """Set up the parser with the NN archive."""
    extraParams = (
        nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
    )

    if parser_name == "SCRFDParser":
        setup_scrfd_parser(parser, extraParams)
    elif parser_name == "SegmentationParser":
        setup_segmentation_parser(parser, extraParams)
    elif parser_name == "KeypointParser":
        setup_keypoint_parser(parser, extraParams)
    elif parser_name == "ClassificationParser":
        setup_classification_parser(parser, extraParams)
    elif parser_name == "MonocularDepthParser":
        setup_monocular_depth_parser(parser, extraParams)
    elif parser_name == "XFeatParser":
        setup_xfeat_parser(parser, extraParams)
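
To make the dispatch above concrete, a minimal sketch of configuring one parser by hand; the extraParams values here are hypothetical, only the key names come from setup_classification_parser:

    # Hypothetical metadata; a real archive supplies this via
    # nn_archive.getConfig().getConfigV1().model.heads[0].metadata.extraParams
    extra_params = {"classes": ["cat", "dog"], "is_softmax": True}
    parser = pipeline.create(ClassificationParser)  # assumes an existing dai.Pipeline
    setup_classification_parser(parser, extra_params)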
Empty file added examples/visualization/__init__.py
57 changes: 57 additions & 0 deletions examples/visualization/classification.py
@@ -0,0 +1,57 @@
import cv2
import numpy as np

from depthai_nodes.ml.messages import AgeGender, Classifications

from .messages import parse_classification_message, parser_age_gender_message


def visualize_classification(
    frame: np.ndarray, message: Classifications, extraParams: dict
):
    """Visualizes the classification on the frame."""
    classes, scores = parse_classification_message(message)
    classes = classes[:2]
    scores = scores[:2]
    if frame.shape[0] < 128:
        frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2))
    for i, (cls, score) in enumerate(zip(classes, scores)):
        cv2.putText(
            frame,
            f"{cls}: {score:.2f}",
            (10, 20 + 20 * i),
            cv2.FONT_HERSHEY_TRIPLEX,
            0.5,
            255,
        )

    cv2.imshow("Classification", frame)
    if cv2.waitKey(1) == ord("q"):
        cv2.destroyAllWindows()
        return True

    return False


def visualize_age_gender(frame: np.ndarray, message: AgeGender, extraParams: dict):
    """Visualizes the age and predicted gender on the frame."""
    if frame.shape[0] < 128:
        frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2))
    age, gender_classes, gender_scores = parser_age_gender_message(message)
    cv2.putText(frame, f"Age: {age}", (10, 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
    for i, (cls, score) in enumerate(zip(gender_classes, gender_scores)):
        cv2.putText(
            frame,
            f"{cls}: {score:.2f}",
            (10, 40 + 20 * i),
            cv2.FONT_HERSHEY_TRIPLEX,
            0.5,
            255,
        )

    cv2.imshow("Age-gender", frame)
    if cv2.waitKey(1) == ord("q"):
        cv2.destroyAllWindows()
        return True

    return False
