YOLOv8 kpts and seg example. #48

Merged (1 commit) on Sep 3, 2024
14 changes: 14 additions & 0 deletions examples/utils/parser.py
@@ -7,6 +7,7 @@
    SCRFDParser,
    SegmentationParser,
    XFeatParser,
    YOLOExtendedParser,
)


@@ -86,6 +87,17 @@ def setup_xfeat_parser(parser: XFeatParser, params: dict):
        )


def setup_yolo_extended_parser(parser: YOLOExtendedParser, params: dict):
    """Setup the YOLO parser with the required metadata."""
    try:
        n_classes = params["n_classes"]
        parser.setNumClasses(n_classes)
    except Exception:
        print(
            "This NN archive does not have required metadata for YOLOExtendedParser. Skipping setup..."
        )


def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
"""Setup the parser with the NN archive."""

@@ -105,3 +117,5 @@ def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_name: str):
        setup_monocular_depth_parser(parser, extraParams)
    elif parser_name == "XFeatParser":
        setup_xfeat_parser(parser, extraParams)
    elif parser_name == "YOLOExtendedParser":
        setup_yolo_extended_parser(parser, extraParams)
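
A minimal, self-contained sketch of how the new helper behaves, assuming setup_yolo_extended_parser from the file above is in scope. The stub class below only stands in for depthai_nodes' YOLOExtendedParser so the snippet runs without a device; its setNumClasses method and the "n_classes" key mirror the calls in the diff above.

class _StubYoloParser:
    """Illustrative stand-in for YOLOExtendedParser; only mimics setNumClasses."""

    def setNumClasses(self, n_classes):
        self.n_classes = n_classes


stub = _StubYoloParser()
setup_yolo_extended_parser(stub, {"n_classes": 80})  # e.g. a COCO-style model
setup_yolo_extended_parser(stub, {})  # missing key: prints the warning and skips setup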
27 changes: 27 additions & 0 deletions examples/visualization/colors.py
@@ -37,3 +37,30 @@ def get_ewasr_colors():
        (0, 0, 0), # class 2 - green
    ]
    return colors


def get_yolo_colors():
    colors = [
        [255, 128, 0],
        [255, 153, 51],
        [255, 178, 102],
        [230, 230, 0],
        [255, 153, 255],
        [153, 204, 255],
        [255, 102, 255],
        [255, 51, 255],
        [102, 178, 255],
        [51, 153, 255],
        [255, 153, 153],
        [255, 102, 102],
        [255, 51, 51],
        [153, 255, 153],
        [102, 255, 102],
        [51, 255, 51],
        [0, 255, 0],
        [0, 0, 255],
        [255, 0, 0],
        [255, 255, 255],
    ]

    return colors
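
The palette above has 20 entries and is indexed by class label with a modulo wrap, which is how the segmentation overlay in detection.py consumes it (colors[label % len(colors)]). A small illustration, assuming get_yolo_colors from this file is in scope:

colors = get_yolo_colors()
for label in (0, 5, 19, 20, 37):
    color = colors[label % len(colors)]  # labels beyond 19 wrap around the palette
    print(f"label {label} -> color {color}")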
79 changes: 77 additions & 2 deletions examples/visualization/detection.py
@@ -1,9 +1,15 @@
import cv2
import depthai as dai
import numpy as np

from depthai_nodes.ml.messages import Lines
from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines

from .messages import parse_detection_message, parse_line_detection_message
from .colors import get_yolo_colors
from .messages import (
    parse_detection_message,
    parse_line_detection_message,
    parse_yolo_kpts_message,
)


def visualize_detections(
@@ -100,3 +106,72 @@ def visualize_line_detections(frame: dai.ImgFrame, message: Lines, extraParams:
        return True

    return False


def visualize_yolo_extended(
    frame: dai.ImgFrame, message: ImgDetectionsExtended, extraParams: dict
):
    """Visualizes the YOLO pose detections or instance segmentation on the frame."""
    detections = parse_yolo_kpts_message(message)

    colors = get_yolo_colors()
    classes = extraParams.get("classes", None)
    if classes is None:
        raise ValueError("Classes are required for visualization.")
    task = extraParams.get("n_keypoints", None)
    if task is None:
        task = "segmentation"
    else:
        task = "keypoints"

    overlay = np.zeros_like(frame)

    for detection in detections:
        xmin, ymin, xmax, ymax = (
            detection.xmin,
            detection.ymin,
            detection.xmax,
            detection.ymax,
        )
        cv2.rectangle(
            frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 2
        )
        cv2.putText(
            frame,
            f"{detection.confidence * 100:.2f}%",
            (int(xmin) + 10, int(ymin) + 20),
            cv2.FONT_HERSHEY_TRIPLEX,
            0.5,
            255,
        )
        cv2.putText(
            frame,
            f"{classes[detection.label]}",
            (int(xmin) + 10, int(ymin) + 40),
            cv2.FONT_HERSHEY_TRIPLEX,
            0.5,
            255,
        )

        if task == "keypoints":
            keypoints = detection.keypoints
            for keypoint in keypoints:
                x, y, visibility = keypoint[0], keypoint[1], keypoint[2]
                if visibility > 0.8:
                    cv2.circle(frame, (int(x), int(y)), 2, (0, 255, 0), -1)
        else:
            mask = detection.mask
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]))
            label = detection.label
            overlay[mask > 0] = colors[label % len(colors)]

    if task == "segmentation":
        frame = cv2.addWeighted(overlay, 0.8, frame, 0.5, 0, None)

    cv2.imshow("YOLO Pose Estimation", frame)
    if cv2.waitKey(1) == ord("q"):
        cv2.destroyAllWindows()
        return True

    return False
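
The segmentation branch above builds a per-detection color overlay and alpha-blends it onto the frame with cv2.addWeighted. A standalone sketch of that blend using synthetic data (no device, parser, or real masks involved; the frame size and mask region are arbitrary placeholders):

import cv2
import numpy as np

frame = np.zeros((288, 512, 3), dtype=np.uint8)  # placeholder frame
mask = np.zeros((288, 512), dtype=np.uint8)
mask[100:200, 150:350] = 1  # pretend instance mask for one detection
overlay = np.zeros_like(frame)
overlay[mask > 0] = (255, 128, 0)  # first entry of the palette
blended = cv2.addWeighted(overlay, 0.8, frame, 0.5, 0, None)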
7 changes: 6 additions & 1 deletion examples/visualization/mapping.py
@@ -1,5 +1,9 @@
from .classification import visualize_age_gender, visualize_classification
from .detection import visualize_detections, visualize_line_detections
from .detection import (
    visualize_detections,
    visualize_line_detections,
    visualize_yolo_extended,
)
from .image import visualize_image
from .keypoints import visualize_keypoints
from .segmentation import visualize_segmentation
@@ -20,4 +24,5 @@
"ImageOutputParser": visualize_image,
"MonocularDepthParser": visualize_image,
"AgeGenderParser": visualize_age_gender,
"YOLOExtendedParser": visualize_yolo_extended,
}
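
A hedged sketch of how a mapping like this is typically consumed on the host side: look the visualizer up by the parser's name and call it per frame. The dict and variable names below are illustrative assumptions, not names confirmed by this diff.

visualizers = {
    "YOLOExtendedParser": visualize_yolo_extended,  # subset of the mapping above
}

parser_name = "YOLOExtendedParser"
visualize = visualizers.get(parser_name)
if visualize is None:
    raise ValueError(f"No visualizer registered for {parser_name}")
# visualize(frame, message, extraParams)  # would be called once per received message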
14 changes: 13 additions & 1 deletion examples/visualization/messages.py
@@ -1,6 +1,12 @@
import depthai as dai

from depthai_nodes.ml.messages import AgeGender, Classifications, Keypoints, Lines
from depthai_nodes.ml.messages import (
    AgeGender,
    Classifications,
    ImgDetectionsExtended,
    Keypoints,
    Lines,
)


def parse_detection_message(message: dai.ImgDetections):
@@ -50,3 +56,9 @@ def parser_age_gender_message(message: AgeGender):
    gender_classes = gender.classes

    return age, gender_classes, gender_scores


def parse_yolo_kpts_message(message: ImgDetectionsExtended):
    """Parses the YOLO keypoints/segmentation message and returns its detections."""
    detections = message.detections
    return detections
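
parse_yolo_kpts_message simply unwraps the message's detections list; each item carries the fields the visualizer reads (xmin/ymin/xmax/ymax, confidence, label, plus keypoints or a mask). A tiny stand-in example, assuming the function above is in scope; real messages come from YOLOExtendedParser's output, so the SimpleNamespace objects here are purely illustrative.

from types import SimpleNamespace

fake_message = SimpleNamespace(
    detections=[
        SimpleNamespace(label=0, confidence=0.91, xmin=10, ymin=20, xmax=110, ymax=220),
    ]
)

for det in parse_yolo_kpts_message(fake_message):
    print(det.label, f"{det.confidence:.2f}", det.xmin, det.ymin, det.xmax, det.ymax)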