Skip to content

Commit 501a2c9

Browse files
Add rotated bounding box formats (#8841)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent f709766 commit 501a2c9

File tree

7 files changed

+401
-60
lines changed

7 files changed

+401
-60
lines changed

test/common_utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def sample_position(values, max_value):
423423
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
424424
y = sample_position(h, canvas_size[0])
425425
x = sample_position(w, canvas_size[1])
426+
r = -360 * torch.rand((num_boxes,)) + 180
426427

427428
if format is tv_tensors.BoundingBoxFormat.XYWH:
428429
parts = (x, y, w, h)
@@ -435,6 +436,23 @@ def sample_position(values, max_value):
435436
cx = x + w / 2
436437
cy = y + h / 2
437438
parts = (cx, cy, w, h)
439+
elif format is tv_tensors.BoundingBoxFormat.XYWHR:
440+
parts = (x, y, w, h, r)
441+
elif format is tv_tensors.BoundingBoxFormat.CXCYWHR:
442+
cx = x + w / 2
443+
cy = y + h / 2
444+
parts = (cx, cy, w, h, r)
445+
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
446+
r_rad = r * torch.pi / 180.0
447+
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
448+
x1, y1 = x, y
449+
x3 = x1 + w * cos
450+
y3 = y1 - w * sin
451+
x2 = x3 + h * sin
452+
y2 = y3 + h * cos
453+
x4 = x1 + h * sin
454+
y4 = y1 + h * cos
455+
parts = (x1, y1, x3, y3, x2, y2, x4, y4)
438456
else:
439457
raise ValueError(f"Format {format} is not supported")
440458

test/test_ops.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -1339,8 +1339,61 @@ def test_bbox_xywh_cxcywh(self):
13391339
box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
13401340
assert_equal(box_xywh, box_tensor)
13411341

1342-
@pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
1343-
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
1342+
def test_bbox_xywhr_cxcywhr(self):
1343+
box_tensor = torch.tensor(
1344+
[
1345+
[0, 0, 100, 100, 0],
1346+
[0, 0, 0, 0, 0],
1347+
[10, 15, 20, 20, 0],
1348+
[23, 35, 70, 60, 0],
1349+
[4, 2, 4, 2, 0],
1350+
[5, 5, 4, 2, 90],
1351+
[8, 4, 4, 2, 180],
1352+
[7, 1, 4, 2, -90],
1353+
],
1354+
dtype=torch.float,
1355+
)
1356+
1357+
exp_cxcywhr = torch.tensor(
1358+
[
1359+
[50, 50, 100, 100, 0],
1360+
[0, 0, 0, 0, 0],
1361+
[20, 25, 20, 20, 0],
1362+
[58, 65, 70, 60, 0],
1363+
[6, 3, 4, 2, 0],
1364+
[6, 3, 4, 2, 90],
1365+
[6, 3, 4, 2, 180],
1366+
[6, 3, 4, 2, -90],
1367+
],
1368+
dtype=torch.float,
1369+
)
1370+
1371+
assert exp_cxcywhr.size() == torch.Size([8, 5])
1372+
box_cxcywhr = ops.box_convert(box_tensor, in_fmt="xywhr", out_fmt="cxcywhr")
1373+
torch.testing.assert_close(box_cxcywhr, exp_cxcywhr)
1374+
1375+
# Reverse conversion
1376+
box_xywhr = ops.box_convert(box_cxcywhr, in_fmt="cxcywhr", out_fmt="xywhr")
1377+
torch.testing.assert_close(box_xywhr, box_tensor)
1378+
1379+
def test_bbox_cxcywhr_to_xyxyxyxy(self):
1380+
box_tensor = torch.tensor([[5, 3, 4, 2, 90]], dtype=torch.float)
1381+
exp_xyxyxyxy = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float)
1382+
1383+
assert exp_xyxyxyxy.size() == torch.Size([1, 8])
1384+
box_xyxyxyxy = ops.box_convert(box_tensor, in_fmt="cxcywhr", out_fmt="xyxyxyxy")
1385+
torch.testing.assert_close(box_xyxyxyxy, exp_xyxyxyxy)
1386+
1387+
def test_bbox_xywhr_to_xyxyxyxy(self):
1388+
box_tensor = torch.tensor([[4, 5, 4, 2, 90]], dtype=torch.float)
1389+
exp_xyxyxyxy = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float)
1390+
1391+
assert exp_xyxyxyxy.size() == torch.Size([1, 8])
1392+
box_xyxyxyxy = ops.box_convert(box_tensor, in_fmt="xywhr", out_fmt="xyxyxyxy")
1393+
torch.testing.assert_close(box_xyxyxyxy, exp_xyxyxyxy)
1394+
1395+
@pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh", "xwyhr", "cxwyhr", "xxxxyyyy"])
1396+
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy", "xwcxr", "xhwcyr", "xyxyxxyy"])
13441397
def test_bbox_invalid(self, inv_infmt, inv_outfmt):
13451398
box_tensor = torch.tensor(
13461399
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float

test/test_transforms_v2.py

+45-31
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@
5353
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
5454

5555

56+
# While we are working on adjusting transform functions
57+
# for rotated and oriented bounding boxes formats,
58+
# we limit the perimeter of tests to formats
59+
# for which transform functions are already implemented.
60+
# In the future, this global variable will be replaced with `list(tv_tensors.BoundingBoxFormat)`
61+
# to support all available formats.
62+
SUPPORTED_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYXY", "XYWH", "CXCYWH"]]
63+
NEW_BOX_FORMATS = [tv_tensors.BoundingBoxFormat[x] for x in ["XYWHR", "CXCYWHR", "XYXYXYXY"]]
64+
5665
# turns all warnings into errors for this module
5766
pytestmark = [pytest.mark.filterwarnings("error")]
5867

@@ -626,7 +635,7 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
626635
check_scripted_vs_eager=not isinstance(size, int),
627636
)
628637

629-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
638+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
630639
@pytest.mark.parametrize("size", OUTPUT_SIZES)
631640
@pytest.mark.parametrize("use_max_size", [True, False])
632641
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -757,7 +766,7 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
757766
new_canvas_size=(new_height, new_width),
758767
)
759768

760-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
769+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
761770
@pytest.mark.parametrize("size", OUTPUT_SIZES)
762771
@pytest.mark.parametrize("use_max_size", [True, False])
763772
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
@@ -1003,7 +1012,7 @@ class TestHorizontalFlip:
10031012
def test_kernel_image(self, dtype, device):
10041013
check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
10051014

1006-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1015+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
10071016
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
10081017
@pytest.mark.parametrize("device", cpu_and_cuda())
10091018
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -1072,7 +1081,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
10721081

10731082
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
10741083

1075-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1084+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
10761085
@pytest.mark.parametrize(
10771086
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
10781087
)
@@ -1169,7 +1178,7 @@ def test_kernel_image(self, param, value, dtype, device):
11691178
shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
11701179
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
11711180
)
1172-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1181+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
11731182
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
11741183
@pytest.mark.parametrize("device", cpu_and_cuda())
11751184
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
@@ -1318,7 +1327,7 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate,
13181327
),
13191328
)
13201329

1321-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1330+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
13221331
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
13231332
@pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
13241333
@pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
@@ -1346,7 +1355,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, translate, s
13461355

13471356
torch.testing.assert_close(actual, expected)
13481357

1349-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1358+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
13501359
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
13511360
@pytest.mark.parametrize("seed", list(range(5)))
13521361
def test_transform_bounding_boxes_correctness(self, format, center, seed):
@@ -1453,7 +1462,7 @@ class TestVerticalFlip:
14531462
def test_kernel_image(self, dtype, device):
14541463
check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
14551464

1456-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1465+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
14571466
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
14581467
@pytest.mark.parametrize("device", cpu_and_cuda())
14591468
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -1520,7 +1529,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
15201529

15211530
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
15221531

1523-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1532+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
15241533
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
15251534
def test_bounding_boxes_correctness(self, format, fn):
15261535
bounding_boxes = make_bounding_boxes(format=format)
@@ -1589,7 +1598,7 @@ def test_kernel_image(self, param, value, dtype, device):
15891598
expand=[False, True],
15901599
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
15911600
)
1592-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1601+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
15931602
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
15941603
@pytest.mark.parametrize("device", cpu_and_cuda())
15951604
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
@@ -1760,7 +1769,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
17601769
bounding_boxes
17611770
)
17621771

1763-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1772+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
17641773
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
17651774
@pytest.mark.parametrize("expand", [False, True])
17661775
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@@ -1773,7 +1782,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent
17731782
torch.testing.assert_close(actual, expected)
17741783
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
17751784

1776-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
1785+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
17771786
@pytest.mark.parametrize("expand", [False, True])
17781787
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
17791788
@pytest.mark.parametrize("seed", list(range(5)))
@@ -2694,7 +2703,7 @@ def test_kernel_image(self, param, value, dtype, device):
26942703
check_cuda_vs_cpu=dtype is not torch.float16,
26952704
)
26962705

2697-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
2706+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
26982707
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
26992708
@pytest.mark.parametrize("device", cpu_and_cuda())
27002709
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -2821,7 +2830,7 @@ def test_kernel_image(self, kwargs, dtype, device):
28212830
check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs)
28222831

28232832
@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
2824-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
2833+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
28252834
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
28262835
@pytest.mark.parametrize("device", cpu_and_cuda())
28272836
def test_kernel_bounding_box(self, kwargs, format, dtype, device):
@@ -2971,7 +2980,7 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
29712980
)
29722981

29732982
@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
2974-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
2983+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
29752984
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
29762985
@pytest.mark.parametrize("device", cpu_and_cuda())
29772986
def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
@@ -2984,7 +2993,7 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
29842993
assert_equal(F.get_size(actual), F.get_size(expected))
29852994

29862995
@pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)])
2987-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
2996+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
29882997
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
29892998
@pytest.mark.parametrize("device", cpu_and_cuda())
29902999
@pytest.mark.parametrize("seed", list(range(5)))
@@ -3507,7 +3516,8 @@ def test_aug_mix_severity_error(self, severity):
35073516

35083517

35093518
class TestConvertBoundingBoxFormat:
3510-
old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2))
3519+
old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2))
3520+
old_new_formats += list(itertools.permutations(NEW_BOX_FORMATS, 2))
35113521

35123522
@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
35133523
def test_kernel(self, old_format, new_format):
@@ -3518,7 +3528,7 @@ def test_kernel(self, old_format, new_format):
35183528
old_format=old_format,
35193529
)
35203530

3521-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
3531+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
35223532
@pytest.mark.parametrize("inplace", [False, True])
35233533
def test_kernel_noop(self, format, inplace):
35243534
input = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
@@ -3542,9 +3552,13 @@ def test_kernel_inplace(self, old_format, new_format):
35423552
output_inplace = F.convert_bounding_box_format(
35433553
input, old_format=old_format, new_format=new_format, inplace=True
35443554
)
3545-
assert output_inplace.data_ptr() == input.data_ptr()
3546-
assert output_inplace._version > input_version
3547-
assert output_inplace is input
3555+
if old_format != tv_tensors.BoundingBoxFormat.XYXYXYXY and new_format != tv_tensors.BoundingBoxFormat.XYXYXYXY:
3556+
# NOTE: BoundingBox format conversion from and to XYXYXYXY format
3557+
# cannot modify the input tensor inplace as it requires a dimension
3558+
# change.
3559+
assert output_inplace.data_ptr() == input.data_ptr()
3560+
assert output_inplace._version > input_version
3561+
assert output_inplace is input
35483562

35493563
assert_equal(output_inplace, output_out_of_place)
35503564

@@ -3563,7 +3577,7 @@ def test_transform(self, old_format, new_format, format_type):
35633577
@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
35643578
def test_strings(self, old_format, new_format):
35653579
# Non-regression test for https://github.com/pytorch/vision/issues/8258
3566-
input = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50))
3580+
input = make_bounding_boxes(format=old_format, canvas_size=(50, 50))
35673581
expected = self._reference_convert_bounding_box_format(input, new_format)
35683582

35693583
old_format = old_format.name
@@ -3728,7 +3742,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
37283742
new_canvas_size=size,
37293743
)
37303744

3731-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
3745+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
37323746
def test_functional_bounding_boxes_correctness(self, format):
37333747
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
37343748

@@ -3796,7 +3810,7 @@ def test_kernel_image(self, param, value, dtype, device):
37963810
),
37973811
)
37983812

3799-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
3813+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
38003814
def test_kernel_bounding_boxes(self, format):
38013815
bounding_boxes = make_bounding_boxes(format=format)
38023816
check_kernel(
@@ -3915,7 +3929,7 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
39153929
)
39163930

39173931
@pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
3918-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
3932+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
39193933
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
39203934
@pytest.mark.parametrize("device", cpu_and_cuda())
39213935
@pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
@@ -3944,7 +3958,7 @@ def test_kernel_image(self, output_size, dtype, device):
39443958
)
39453959

39463960
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
3947-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
3961+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
39483962
def test_kernel_bounding_boxes(self, output_size, format):
39493963
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
39503964
check_kernel(
@@ -4023,7 +4037,7 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
40234037
)
40244038

40254039
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
4026-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
4040+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
40274041
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
40284042
@pytest.mark.parametrize("device", cpu_and_cuda())
40294043
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
@@ -4090,7 +4104,7 @@ def test_kernel_image_error(self):
40904104
coefficients=COEFFICIENTS,
40914105
start_end_points=START_END_POINTS,
40924106
)
4093-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
4107+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
40944108
def test_kernel_bounding_boxes(self, param, value, format):
40954109
if param == "start_end_points":
40964110
kwargs = dict(zip(["startpoints", "endpoints"], value))
@@ -4266,7 +4280,7 @@ def perspective_bounding_boxes(bounding_boxes):
42664280
)
42674281

42684282
@pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
4269-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
4283+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
42704284
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
42714285
@pytest.mark.parametrize("device", cpu_and_cuda())
42724286
def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):
@@ -4473,7 +4487,7 @@ def test_correctness_image(self, mean, std, dtype, fn):
44734487

44744488

44754489
class TestClampBoundingBoxes:
4476-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
4490+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
44774491
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
44784492
@pytest.mark.parametrize("device", cpu_and_cuda())
44794493
def test_kernel(self, format, dtype, device):
@@ -4485,7 +4499,7 @@ def test_kernel(self, format, dtype, device):
44854499
canvas_size=bounding_boxes.canvas_size,
44864500
)
44874501

4488-
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
4502+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
44894503
def test_functional(self, format):
44904504
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))
44914505

0 commit comments

Comments
 (0)