Skip to content

Commit e9168be

Browse files
committed
nms_thr and score_thr as optional parameters when calling rtmo
1 parent 8031578 commit e9168be

File tree

1 file changed

+13
-4
lines changed
  • rtmlib/tools/pose_estimation

1 file changed

+13
-4
lines changed

rtmlib/tools/pose_estimation/rtmo.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,25 @@ def __init__(self,
1515
model_input_size: tuple = (640, 640),
1616
mean: tuple = None,
1717
std: tuple = None,
18+
nms_thr: float = 0.45,
19+
score_thr: float = 0.7,
1820
to_openpose: bool = False,
1921
backend: str = 'onnxruntime',
2022
device: str = 'cpu'):
2123
super().__init__(onnx_model, model_input_size, mean, std, backend,
2224
device)
2325
self.to_openpose = to_openpose
26+
self.nms_thr = nms_thr
27+
self.score_thr = score_thr
28+
29+
def __call__(self, image: np.ndarray, nms_thr: float = None, score_thr: float = None):
30+
nms_thr = nms_thr if nms_thr is not None else self.nms_thr
31+
score_thr = score_thr if score_thr is not None else self.score_thr
2432

25-
def __call__(self, image: np.ndarray):
2633
image, ratio = self.preprocess(image)
2734
outputs = self.inference(image)
2835

29-
keypoints, scores = self.postprocess(outputs, ratio)
36+
keypoints, scores = self.postprocess(outputs, ratio, nms_thr, score_thr)
3037

3138
if self.to_openpose:
3239
keypoints, scores = convert_coco_to_openpose(keypoints, scores)
@@ -74,6 +81,8 @@ def postprocess(
7481
self,
7582
outputs: List[np.ndarray],
7683
ratio: float = 1.,
84+
nms_thr: float = None,
85+
score_thr: float = None,
7786
) -> Tuple[np.ndarray, np.ndarray]:
7887
"""Do postprocessing for RTMO model inference.
7988
@@ -97,8 +106,8 @@ def postprocess(
97106
# apply nms
98107
dets, keep = multiclass_nms(final_boxes,
99108
final_scores[:, np.newaxis],
100-
nms_thr=0.45,
101-
score_thr=0.7)
109+
nms_thr=nms_thr,
110+
score_thr=score_thr)
102111
if keep is not None:
103112
keypoints = keypoints[keep]
104113
scores = scores[keep]

0 commit comments

Comments
 (0)