-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improved SCRFD decoding. #21
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
import cv2 | ||
import depthai as dai | ||
import numpy as np | ||
|
||
from ..messages.creators import create_detection_message | ||
from .utils.scrfd import decode_scrfd | ||
|
||
|
||
class SCRFDParser(dai.node.ThreadedHostNode): | ||
|
@@ -20,6 +20,12 @@ class SCRFDParser(dai.node.ThreadedHostNode): | |
Non-maximum suppression threshold. | ||
top_k : int | ||
Maximum number of detections to keep. | ||
feat_stride_fpn : tuple | ||
Tuple of the feature strides. | ||
num_anchors : int | ||
Number of anchors. | ||
input_size : tuple | ||
Input size of the model. | ||
|
||
Output Message/s | ||
---------------- | ||
|
@@ -28,7 +34,15 @@ class SCRFDParser(dai.node.ThreadedHostNode): | |
**Description**: ImgDetections message containing bounding boxes, labels, and confidence scores of detected faces. | ||
""" | ||
|
||
def __init__(self, score_threshold=0.5, nms_threshold=0.5, top_k=100): | ||
def __init__( | ||
self, | ||
score_threshold=0.5, | ||
nms_threshold=0.5, | ||
top_k=100, | ||
input_size=(640, 640), | ||
feat_stride_fpn=(8, 16, 32), | ||
num_anchors=2, | ||
): | ||
"""Initializes the SCRFDParser node. | ||
|
||
@param score_threshold: Confidence score threshold for detected faces. | ||
|
@@ -37,6 +51,12 @@ def __init__(self, score_threshold=0.5, nms_threshold=0.5, top_k=100): | |
@type nms_threshold: float | ||
@param top_k: Maximum number of detections to keep. | ||
@type top_k: int | ||
@param feat_stride_fpn: List of the feature strides. | ||
@type feat_stride_fpn: tuple | ||
@param num_anchors: Number of anchors. | ||
@type num_anchors: int | ||
@param input_size: Input size of the model. | ||
@type input_size: tuple | ||
""" | ||
dai.node.ThreadedHostNode.__init__(self) | ||
self.input = dai.Node.Input(self) | ||
|
@@ -46,6 +66,10 @@ def __init__(self, score_threshold=0.5, nms_threshold=0.5, top_k=100): | |
self.nms_threshold = nms_threshold | ||
self.top_k = top_k | ||
|
||
self.feat_stride_fpn = feat_stride_fpn | ||
self.num_anchors = num_anchors | ||
self.input_size = input_size | ||
|
||
def setConfidenceThreshold(self, threshold): | ||
"""Sets the confidence score threshold for detected faces. | ||
|
||
|
@@ -70,108 +94,100 @@ def setTopK(self, top_k): | |
""" | ||
self.top_k = top_k | ||
|
||
def setFeatStrideFPN(self, feat_stride_fpn): | ||
"""Sets the feature stride of the FPN. | ||
|
||
@param feat_stride_fpn: Feature stride of the FPN. | ||
@type feat_stride_fpn: list | ||
""" | ||
self.feat_stride_fpn = feat_stride_fpn | ||
|
||
def setInputSize(self, input_size): | ||
"""Sets the input size of the model. | ||
|
||
@param input_size: Input size of the model. | ||
@type input_size: list | ||
""" | ||
self.input_size = input_size | ||
|
||
def setNumAnchors(self, num_anchors): | ||
"""Sets the number of anchors. | ||
|
||
@param num_anchors: Number of anchors. | ||
@type num_anchors: int | ||
""" | ||
self.num_anchors = num_anchors | ||
|
||
def run(self): | ||
while self.isRunning(): | ||
try: | ||
output: dai.NNData = self.input.get() | ||
except dai.MessageQueue.QueueException: | ||
break # Pipeline was stopped | ||
|
||
score_8 = output.getTensor("score_8").flatten().astype(np.float32) | ||
score_16 = output.getTensor("score_16").flatten().astype(np.float32) | ||
score_32 = output.getTensor("score_32").flatten().astype(np.float32) | ||
score_8 = ( | ||
output.getTensor("score_8", dequantize=True) | ||
.flatten() | ||
.astype(np.float32) | ||
) | ||
score_16 = ( | ||
output.getTensor("score_16", dequantize=True) | ||
.flatten() | ||
.astype(np.float32) | ||
) | ||
score_32 = ( | ||
output.getTensor("score_32", dequantize=True) | ||
.flatten() | ||
.astype(np.float32) | ||
) | ||
bbox_8 = ( | ||
output.getTensor("bbox_8").reshape(len(score_8), 4).astype(np.float32) | ||
output.getTensor("bbox_8", dequantize=True) | ||
.reshape(len(score_8), 4) | ||
.astype(np.float32) | ||
) | ||
bbox_16 = ( | ||
output.getTensor("bbox_16").reshape(len(score_16), 4).astype(np.float32) | ||
output.getTensor("bbox_16", dequantize=True) | ||
.reshape(len(score_16), 4) | ||
.astype(np.float32) | ||
) | ||
bbox_32 = ( | ||
output.getTensor("bbox_32").reshape(len(score_32), 4).astype(np.float32) | ||
output.getTensor("bbox_32", dequantize=True) | ||
.reshape(len(score_32), 4) | ||
.astype(np.float32) | ||
) | ||
kps_8 = ( | ||
output.getTensor("kps_8").reshape(len(score_8), 5, 2).astype(np.float32) | ||
output.getTensor("kps_8", dequantize=True) | ||
.reshape(len(score_8), 10) | ||
.astype(np.float32) | ||
) | ||
kps_16 = ( | ||
output.getTensor("kps_16") | ||
.reshape(len(score_16), 5, 2) | ||
output.getTensor("kps_16", dequantize=True) | ||
.reshape(len(score_16), 10) | ||
.astype(np.float32) | ||
) | ||
kps_32 = ( | ||
output.getTensor("kps_32") | ||
.reshape(len(score_32), 5, 2) | ||
output.getTensor("kps_32", dequantize=True) | ||
.reshape(len(score_32), 10) | ||
.astype(np.float32) | ||
) | ||
|
||
bboxes = [] | ||
keypoints = [] | ||
|
||
for i in range(len(score_8)): | ||
y = int(np.floor(i / 80)) * 4 | ||
x = (i % 160) * 4 | ||
bbox = bbox_8[i] | ||
xmin = int(x - bbox[0] * 8) | ||
ymin = int(y - bbox[1] * 8) | ||
xmax = int(x + bbox[2] * 8) | ||
ymax = int(y + bbox[3] * 8) | ||
kps = kps_8[i] | ||
kps_batch = [] | ||
for kp in kps: | ||
kpx = int(x + kp[0] * 8) | ||
kpy = int(y + kp[1] * 8) | ||
kps_batch.append([kpx, kpy]) | ||
keypoints.append(kps_batch) | ||
bbox = [xmin, ymin, xmax, ymax] | ||
bboxes.append(bbox) | ||
|
||
for i in range(len(score_16)): | ||
y = int(np.floor(i / 40)) * 8 | ||
x = (i % 80) * 8 | ||
bbox = bbox_16[i] | ||
xmin = int(x - bbox[0] * 16) | ||
ymin = int(y - bbox[1] * 16) | ||
xmax = int(x + bbox[2] * 16) | ||
ymax = int(y + bbox[3] * 16) | ||
kps = kps_16[i] | ||
kps_batch = [] | ||
for kp in kps: | ||
kpx = int(x + kp[0] * 16) | ||
kpy = int(y + kp[1] * 16) | ||
kps_batch.append([kpx, kpy]) | ||
keypoints.append(kps_batch) | ||
bbox = [xmin, ymin, xmax, ymax] | ||
bboxes.append(bbox) | ||
|
||
for i in range(len(score_32)): | ||
y = int(np.floor(i / 20)) * 16 | ||
x = (i % 40) * 16 | ||
bbox = bbox_32[i] | ||
xmin = int(x - bbox[0] * 32) | ||
ymin = int(y - bbox[1] * 32) | ||
xmax = int(x + bbox[2] * 32) | ||
ymax = int(y + bbox[3] * 32) | ||
kps = kps_32[i] | ||
kps_batch = [] | ||
for kp in kps: | ||
kpx = int(x + kp[0] * 32) | ||
kpy = int(y + kp[1] * 32) | ||
kps_batch.append([kpx, kpy]) | ||
keypoints.append(kps_batch) | ||
bbox = [xmin, ymin, xmax, ymax] | ||
bboxes.append(bbox) | ||
|
||
scores = np.concatenate([score_8, score_16, score_32]) | ||
indices = cv2.dnn.NMSBoxes( | ||
bboxes, | ||
list(scores), | ||
self.score_threshold, | ||
self.nms_threshold, | ||
top_k=self.top_k, | ||
bboxes_concatenated = [bbox_8, bbox_16, bbox_32] | ||
scores_concatenated = [score_8, score_16, score_32] | ||
kps_concatenated = [kps_8, kps_16, kps_32] | ||
|
||
bboxes, scores, keypoints = decode_scrfd( | ||
bboxes_concatenated=bboxes_concatenated, | ||
scores_concatenated=scores_concatenated, | ||
kps_concatenated=kps_concatenated, | ||
feat_stride_fpn=self.feat_stride_fpn, | ||
input_size=self.input_size, | ||
num_anchors=self.num_anchors, | ||
score_threshold=self.score_threshold, | ||
nms_threshold=self.nms_threshold, | ||
) | ||
detection_msg = create_detection_message( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In line to PR #19 we should also add timestamp to the message. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, added it when you were reviewing. Hmm I was thinking that it would be best to add it to the creator functions but we need |
||
bboxes, scores, None, keypoints.tolist() | ||
) | ||
bboxes = np.array(bboxes)[indices] | ||
keypoints = np.array(keypoints)[indices] | ||
scores = scores[indices] | ||
|
||
detection_msg = create_detection_message(bboxes, scores, None, None) | ||
detection_msg.setTimestamp(output.getTimestamp()) | ||
|
||
self.out.send(detection_msg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Is there a better name for a variable than
pointcloud
. As I understandkeypoints
are a list of (n,2) keypoints for specific object? Maybe:Feel free to suggest other names though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm yeah good point. I agree with the proposed naming. Added in a6b3b02.