@@ -1,32 +1,19 @@
import itertools
-import pathlib
-import pickle
import random

import numpy as np
import PIL.Image
import pytest

import torch
import torchvision.transforms.v2 as transforms

from common_utils import assert_equal, cpu_and_cuda
-from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
-from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
-from transforms_v2_legacy_utils import (
-    make_bounding_boxes,
-    make_detection_mask,
-    make_image,
-    make_images,
-    make_multiple_bounding_boxes,
-    make_segmentation_mask,
-    make_video,
-    make_videos,
-)
+from torchvision.transforms.v2._utils import is_pure_tensor
+from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_images, make_videos


def make_vanilla_tensor_images(*args, **kwargs):
@@ -41,11 +28,6 @@ def make_pil_images(*args, **kwargs):
        yield to_pil_image(image)


-def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
-    for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs):
-        yield bounding_boxes.data
-
-
def parametrize(transforms_with_inputs):
    return pytest.mark.parametrize(
        ("transform", "input"),
@@ -61,218 +43,6 @@ def parametrize(transforms_with_inputs):
    )


-def auto_augment_adapter(transform, input, device):
-    adapted_input = {}
-    image_or_video_found = False
-    for key, value in input.items():
-        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
-            # AA transforms don't support bounding boxes or masks
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
-            if image_or_video_found:
-                # AA transforms only support a single image or video
-                continue
-            image_or_video_found = True
-        adapted_input[key] = value
-    return adapted_input
-
-
-def linear_transformation_adapter(transform, input, device):
-    flat_inputs = list(input.values())
-    c, h, w = query_chw(
-        [
-            item
-            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
-            if needs_transform
-        ]
-    )
-    num_elements = c * h * w
-    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
-    transform.mean_vector = torch.randn((num_elements,), device=device)
-    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
-
-
-def normalize_adapter(transform, input, device):
-    adapted_input = {}
-    for key, value in input.items():
-        if isinstance(value, PIL.Image.Image):
-            # normalize doesn't support PIL images
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
-            # normalize doesn't support integer images
-            value = F.to_dtype(value, torch.float32, scale=True)
-        adapted_input[key] = value
-    return adapted_input
-
-
-class TestSmoke:
-    @pytest.mark.parametrize(
-        ("transform", "adapter"),
-        [
-            (transforms.RandomErasing(p=1.0), None),
-            (transforms.AugMix(), auto_augment_adapter),
-            (transforms.AutoAugment(), auto_augment_adapter),
-            (transforms.RandAugment(), auto_augment_adapter),
-            (transforms.TrivialAugmentWide(), auto_augment_adapter),
-            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
-            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
-            (transforms.RandomAutocontrast(p=1.0), None),
-            (transforms.RandomEqualize(p=1.0), None),
-            (transforms.RandomInvert(p=1.0), None),
-            (transforms.RandomChannelPermutation(), None),
-            (transforms.RandomPosterize(bits=4, p=1.0), None),
-            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
-            (transforms.CenterCrop([16, 16]), None),
-            (transforms.ElasticTransform(sigma=1.0), None),
-            (transforms.Pad(4), None),
-            (transforms.RandomAffine(degrees=30.0), None),
-            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
-            (transforms.RandomHorizontalFlip(p=1.0), None),
-            (transforms.RandomPerspective(p=1.0), None),
-            (transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
-            (transforms.RandomResizedCrop([16, 16], antialias=True), None),
-            (transforms.RandomRotation(degrees=30), None),
-            (transforms.RandomShortestSize(min_size=10, antialias=True), None),
-            (transforms.RandomVerticalFlip(p=1.0), None),
-            (transforms.Resize([16, 16], antialias=True), None),
-            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
-            (transforms.ClampBoundingBoxes(), None),
-            (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
-            (transforms.ConvertImageDtype(), None),
-            (transforms.GaussianBlur(kernel_size=3), None),
-            (
-                transforms.LinearTransformation(
-                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
-                    # because we neither know the spatial size nor the device at this point
-                    transformation_matrix=torch.empty((1, 1)),
-                    mean_vector=torch.empty((1,)),
-                ),
-                linear_transformation_adapter,
-            ),
-            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
-            (transforms.ToDtype(torch.float64), None),
-            (transforms.UniformTemporalSubsample(num_samples=2), None),
-        ],
-        ids=lambda transform: type(transform).__name__,
-    )
-    @pytest.mark.parametrize("container_type", [dict, list, tuple])
-    @pytest.mark.parametrize(
-        "image_or_video",
-        [
-            make_image(),
-            make_video(),
-            next(make_pil_images(color_spaces=["RGB"])),
-            next(make_vanilla_tensor_images()),
-        ],
-    )
-    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
-    @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
-        transform = de_serialize(transform)
-
-        canvas_size = F.get_size(image_or_video)
-        input = dict(
-            image_or_video=image_or_video,
-            image_tv_tensor=make_image(size=canvas_size),
-            video_tv_tensor=make_video(size=canvas_size),
-            image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
-            bounding_boxes_xyxy=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
-            ),
-            bounding_boxes_xywh=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
-            ),
-            bounding_boxes_cxcywh=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
-            ),
-            bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
-                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
-                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
-                ],
-                format=tv_tensors.BoundingBoxFormat.XYXY,
-                canvas_size=canvas_size,
-            ),
-            bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [0, 0, 1, -1],  # negative height
-                    [0, 0, -1, 1],  # negative width
-                    [0, 0, -1, -1],  # negative height and width
-                ],
-                format=tv_tensors.BoundingBoxFormat.XYWH,
-                canvas_size=canvas_size,
-            ),
-            bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [0, 0, 1, -1],  # negative height
-                    [0, 0, -1, 1],  # negative width
-                    [0, 0, -1, -1],  # negative height and width
-                ],
-                format=tv_tensors.BoundingBoxFormat.CXCYWH,
-                canvas_size=canvas_size,
-            ),
-            detection_mask=make_detection_mask(size=canvas_size),
-            segmentation_mask=make_segmentation_mask(size=canvas_size),
-            int=0,
-            float=0.0,
-            bool=True,
-            none=None,
-            str="str",
-            path=pathlib.Path.cwd(),
-            object=object(),
-            tensor=torch.empty(5),
-            array=np.empty(5),
-        )
-        if adapter is not None:
-            input = adapter(transform, input, device)
-
-        if container_type in {tuple, list}:
-            input = container_type(input.values())
-
-        input_flat, input_spec = tree_flatten(input)
-        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
-        input = tree_unflatten(input_flat, input_spec)
-
-        torch.manual_seed(0)
-        output = transform(input)
-        output_flat, output_spec = tree_flatten(output)
-
-        assert output_spec == input_spec
-
-        for output_item, input_item, should_be_transformed in zip(
-            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
-        ):
-            if should_be_transformed:
-                assert type(output_item) is type(input_item)
-            else:
-                assert output_item is input_item
-
-            if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
-                transform, transforms.ConvertBoundingBoxFormat
-            ):
-                assert output_item.format == input_item.format
-
-        # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
-        # transform that does this), back into a valid one.
-        # TODO: we should test that against all degenerate boxes above
-        for format in list(tv_tensors.BoundingBoxFormat):
-            sample = dict(
-                boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
-                labels=torch.tensor([3]),
-            )
-            assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
-
-
@pytest.mark.parametrize(
    "flat_inputs",
    itertools.permutations(
@@ -543,39 +313,6 @@ def test__get_params(self, min_size, max_size):
        assert shorter in min_size


-class TestLinearTransformation:
-    def test_assertions(self):
-        with pytest.raises(ValueError, match="transformation_matrix should be square"):
-            transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))
-
-        with pytest.raises(ValueError, match="mean_vector should have the same length"):
-            transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))
-
-    @pytest.mark.parametrize(
-        "inpt",
-        [
-            122 * torch.ones(1, 3, 8, 8),
-            122.0 * torch.ones(1, 3, 8, 8),
-            tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
-            PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
-        ],
-    )
-    def test__transform(self, inpt):
-
-        v = 121 * torch.ones(3 * 8 * 8)
-        m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
-        transform = transforms.LinearTransformation(m, v)
-
-        if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="does not support PIL images"):
-                transform(inpt)
-        else:
-            output = transform(inpt)
-            assert isinstance(output, torch.Tensor)
-            assert output.unique() == 3 * 8 * 8
-            assert output.dtype == inpt.dtype
-
-
class TestRandomResize:
    def test__get_params(self):
        min_size = 3