
Commit e442302

NicolasHug authored and facebook-github-bot committed
[fbsync] port sample input smoke test (#7962)
Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Reviewed By: vmoens
Differential Revision: D50789108
fbshipit-source-id: cb38c29356d6ffdba86819bc65b44d2081833d28
1 parent 31a0e2d commit e442302

4 files changed: +246, -311 lines changed

test/common_utils.py (+1, -1)

@@ -420,7 +420,7 @@ def sample_position(values, max_value):
     dtype = dtype or torch.float32
 
     num_objects = 1
-    h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
+    h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size]
     y = sample_position(h, canvas_size[0])
     x = sample_position(w, canvas_size[1])

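The change above only renames the comprehension variable from `c` to `s`, so it no longer reads like a channel count; the sampled values are unchanged. For context, a minimal sketch of what that line does, assuming a hypothetical (32, 32) canvas and a single object:

import torch

canvas_size = (32, 32)  # assumed (height, width); the real values come from the test helpers
num_objects = 1

# `s` iterates over the canvas dimensions, so each box gets a height in [1, canvas_size[0])
# and a width in [1, canvas_size[1]).
h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size]
print(h, w)  # e.g. tensor([17]) tensor([5])
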
test/test_transforms_v2.py (+2, -265)

@@ -1,6 +1,4 @@
 import itertools
-import pathlib
-import pickle
 import random
 
 import numpy as np
@@ -11,22 +9,11 @@
 import torchvision.transforms.v2 as transforms
 
 from common_utils import assert_equal, cpu_and_cuda
-from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
-from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
-from transforms_v2_legacy_utils import (
-    make_bounding_boxes,
-    make_detection_mask,
-    make_image,
-    make_images,
-    make_multiple_bounding_boxes,
-    make_segmentation_mask,
-    make_video,
-    make_videos,
-)
+from torchvision.transforms.v2._utils import is_pure_tensor
+from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_images, make_videos
 
 
 def make_vanilla_tensor_images(*args, **kwargs):
@@ -41,11 +28,6 @@ def make_pil_images(*args, **kwargs):
         yield to_pil_image(image)
 
 
-def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
-    for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs):
-        yield bounding_boxes.data
-
-
 def parametrize(transforms_with_inputs):
     return pytest.mark.parametrize(
         ("transform", "input"),
@@ -61,218 +43,6 @@ def parametrize(transforms_with_inputs):
     )
 
 
-def auto_augment_adapter(transform, input, device):
-    adapted_input = {}
-    image_or_video_found = False
-    for key, value in input.items():
-        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
-            # AA transforms don't support bounding boxes or masks
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
-            if image_or_video_found:
-                # AA transforms only support a single image or video
-                continue
-            image_or_video_found = True
-        adapted_input[key] = value
-    return adapted_input
-
-
-def linear_transformation_adapter(transform, input, device):
-    flat_inputs = list(input.values())
-    c, h, w = query_chw(
-        [
-            item
-            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
-            if needs_transform
-        ]
-    )
-    num_elements = c * h * w
-    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
-    transform.mean_vector = torch.randn((num_elements,), device=device)
-    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
-
-
-def normalize_adapter(transform, input, device):
-    adapted_input = {}
-    for key, value in input.items():
-        if isinstance(value, PIL.Image.Image):
-            # normalize doesn't support PIL images
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
-            # normalize doesn't support integer images
-            value = F.to_dtype(value, torch.float32, scale=True)
-        adapted_input[key] = value
-    return adapted_input
-
-
-class TestSmoke:
-    @pytest.mark.parametrize(
-        ("transform", "adapter"),
-        [
-            (transforms.RandomErasing(p=1.0), None),
-            (transforms.AugMix(), auto_augment_adapter),
-            (transforms.AutoAugment(), auto_augment_adapter),
-            (transforms.RandAugment(), auto_augment_adapter),
-            (transforms.TrivialAugmentWide(), auto_augment_adapter),
-            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
-            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
-            (transforms.RandomAutocontrast(p=1.0), None),
-            (transforms.RandomEqualize(p=1.0), None),
-            (transforms.RandomInvert(p=1.0), None),
-            (transforms.RandomChannelPermutation(), None),
-            (transforms.RandomPosterize(bits=4, p=1.0), None),
-            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
-            (transforms.CenterCrop([16, 16]), None),
-            (transforms.ElasticTransform(sigma=1.0), None),
-            (transforms.Pad(4), None),
-            (transforms.RandomAffine(degrees=30.0), None),
-            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
-            (transforms.RandomHorizontalFlip(p=1.0), None),
-            (transforms.RandomPerspective(p=1.0), None),
-            (transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
-            (transforms.RandomResizedCrop([16, 16], antialias=True), None),
-            (transforms.RandomRotation(degrees=30), None),
-            (transforms.RandomShortestSize(min_size=10, antialias=True), None),
-            (transforms.RandomVerticalFlip(p=1.0), None),
-            (transforms.Resize([16, 16], antialias=True), None),
-            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
-            (transforms.ClampBoundingBoxes(), None),
-            (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
-            (transforms.ConvertImageDtype(), None),
-            (transforms.GaussianBlur(kernel_size=3), None),
-            (
-                transforms.LinearTransformation(
-                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
-                    # because for we neither know the spatial size nor the device at this point
-                    transformation_matrix=torch.empty((1, 1)),
-                    mean_vector=torch.empty((1,)),
-                ),
-                linear_transformation_adapter,
-            ),
-            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
-            (transforms.ToDtype(torch.float64), None),
-            (transforms.UniformTemporalSubsample(num_samples=2), None),
-        ],
-        ids=lambda transform: type(transform).__name__,
-    )
-    @pytest.mark.parametrize("container_type", [dict, list, tuple])
-    @pytest.mark.parametrize(
-        "image_or_video",
-        [
-            make_image(),
-            make_video(),
-            next(make_pil_images(color_spaces=["RGB"])),
-            next(make_vanilla_tensor_images()),
-        ],
-    )
-    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
-    @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
-        transform = de_serialize(transform)
-
-        canvas_size = F.get_size(image_or_video)
-        input = dict(
-            image_or_video=image_or_video,
-            image_tv_tensor=make_image(size=canvas_size),
-            video_tv_tensor=make_video(size=canvas_size),
-            image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
-            bounding_boxes_xyxy=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
-            ),
-            bounding_boxes_xywh=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
-            ),
-            bounding_boxes_cxcywh=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
-            ),
-            bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
-                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
-                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
-                ],
-                format=tv_tensors.BoundingBoxFormat.XYXY,
-                canvas_size=canvas_size,
-            ),
-            bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [0, 0, 1, -1],  # negative height
-                    [0, 0, -1, 1],  # negative width
-                    [0, 0, -1, -1],  # negative height and width
-                ],
-                format=tv_tensors.BoundingBoxFormat.XYWH,
-                canvas_size=canvas_size,
-            ),
-            bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [0, 0, 1, -1],  # negative height
-                    [0, 0, -1, 1],  # negative width
-                    [0, 0, -1, -1],  # negative height and width
-                ],
-                format=tv_tensors.BoundingBoxFormat.CXCYWH,
-                canvas_size=canvas_size,
-            ),
-            detection_mask=make_detection_mask(size=canvas_size),
-            segmentation_mask=make_segmentation_mask(size=canvas_size),
-            int=0,
-            float=0.0,
-            bool=True,
-            none=None,
-            str="str",
-            path=pathlib.Path.cwd(),
-            object=object(),
-            tensor=torch.empty(5),
-            array=np.empty(5),
-        )
-        if adapter is not None:
-            input = adapter(transform, input, device)
-
-        if container_type in {tuple, list}:
-            input = container_type(input.values())
-
-        input_flat, input_spec = tree_flatten(input)
-        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
-        input = tree_unflatten(input_flat, input_spec)
-
-        torch.manual_seed(0)
-        output = transform(input)
-        output_flat, output_spec = tree_flatten(output)
-
-        assert output_spec == input_spec
-
-        for output_item, input_item, should_be_transformed in zip(
-            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
-        ):
-            if should_be_transformed:
-                assert type(output_item) is type(input_item)
-            else:
-                assert output_item is input_item
-
-            if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
-                transform, transforms.ConvertBoundingBoxFormat
-            ):
-                assert output_item.format == input_item.format
-
-        # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
-        # transform that does this), back into a valid one.
-        # TODO: we should test that against all degenerate boxes above
-        for format in list(tv_tensors.BoundingBoxFormat):
-            sample = dict(
-                boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
-                labels=torch.tensor([3]),
-            )
-            assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
-
-
 @pytest.mark.parametrize(
     "flat_inputs",
     itertools.permutations(
@@ -543,39 +313,6 @@ def test__get_params(self, min_size, max_size):
         assert shorter in min_size
 
 
-class TestLinearTransformation:
-    def test_assertions(self):
-        with pytest.raises(ValueError, match="transformation_matrix should be square"):
-            transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))
-
-        with pytest.raises(ValueError, match="mean_vector should have the same length"):
-            transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))
-
-    @pytest.mark.parametrize(
-        "inpt",
-        [
-            122 * torch.ones(1, 3, 8, 8),
-            122.0 * torch.ones(1, 3, 8, 8),
-            tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
-            PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
-        ],
-    )
-    def test__transform(self, inpt):
-
-        v = 121 * torch.ones(3 * 8 * 8)
-        m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
-        transform = transforms.LinearTransformation(m, v)
-
-        if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="does not support PIL images"):
-                transform(inpt)
-        else:
-            output = transform(inpt)
-            assert isinstance(output, torch.Tensor)
-            assert output.unique() == 3 * 8 * 8
-            assert output.dtype == inpt.dtype
-
-
 class TestRandomResize:
     def test__get_params(self):
         min_size = 3

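The hunks above delete the old TestSmoke.test_common and TestLinearTransformation tests; per the commit title, the sample-input smoke test is being ported rather than dropped. For context, the core pattern the removed test relied on is a pytree round trip: flatten an arbitrary sample container, move only the tensor leaves to the target device, and rebuild the original structure before calling the transform. A minimal sketch of that pattern, using a made-up sample dict and the CPU as a stand-in device:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# A mixed sample like the removed test built: tensor leaves plus passthrough values.
sample = {"image": torch.rand(3, 16, 16), "label": 3, "note": "str"}

# Flatten to a list of leaves plus a spec describing the container structure.
flat, spec = tree_flatten(sample)
# Move only tensors; leave ints, strings, etc. untouched.
flat = [leaf.to("cpu") if isinstance(leaf, torch.Tensor) else leaf for leaf in flat]
# Rebuild the original dict with the same keys and ordering.
sample = tree_unflatten(flat, spec)

assert set(sample) == {"image", "label", "note"}
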
test/test_transforms_v2_consistency.py (-22)

@@ -72,28 +72,6 @@ def __init__(
 LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
 
 CONSISTENCY_CONFIGS = [
-    *[
-        ConsistencyConfig(
-            v2_transforms.LinearTransformation,
-            legacy_transforms.LinearTransformation,
-            [
-                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
-            ],
-            # Make sure that the product of the height, width and number of channels matches the number of elements in
-            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-            make_images_kwargs=dict(
-                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
-            ),
-            supports_pil=False,
-        )
-        for matrix_dtype, image_dtype in [
-            (torch.float32, torch.float32),
-            (torch.float64, torch.float64),
-            (torch.float32, torch.uint8),
-            (torch.float64, torch.float32),
-            (torch.float32, torch.float64),
-        ]
-    ],
     ConsistencyConfig(
         v2_transforms.ToPILImage,
         legacy_transforms.ToPILImage,

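The dropped consistency configs above exercised LinearTransformation's shape constraint, which is why they pinned image sizes whose channels * height * width matched the length of LINEAR_TRANSFORMATION_MEAN. For reference, a minimal sketch of that constraint, assuming an illustrative 3x8x8 float image (the values here are not from this commit):

import torch
from torchvision.transforms import v2

d = 3 * 8 * 8  # transformation_matrix must be (d, d) and mean_vector must have length d
transform = v2.LinearTransformation(
    transformation_matrix=torch.eye(d),  # identity matrix: output is just the mean-subtracted input
    mean_vector=torch.zeros(d),
)
out = transform(torch.rand(3, 8, 8))  # spatial shape is preserved: (3, 8, 8)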