Skip to content

Commit 0b41ff0

Browse files
ezyangpmeier
andauthored
Meta implementation for nms (#7944)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent d84aaae commit 0b41ff0

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

torchvision/_meta_registrations.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22

33
import torch
4+
import torch._custom_ops
45
import torch.library
56

67
# Ensure that torch.ops.torchvision is visible
@@ -48,3 +49,17 @@ def meta_roi_align_backward(
4849
),
4950
)
5051
return grad.new_empty((batch_size, channels, height, width))
52+
53+
54+
@torch._custom_ops.impl_abstract("torchvision::nms")
55+
def meta_nms(dets, scores, iou_threshold):
56+
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
57+
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
58+
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
59+
torch._check(
60+
dets.size(0) == scores.size(0),
61+
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
62+
)
63+
ctx = torch._custom_ops.get_ctx()
64+
num_to_keep = ctx.create_unbacked_symint()
65+
return dets.new_empty(num_to_keep, dtype=torch.long)

0 commit comments

Comments
 (0)