@@ -2169,26 +2169,30 @@ def test_image_correctness(self, brightness_factor):
2169
2169
2170
2170
class TestCutMixMixUp :
2171
2171
class DummyDataset :
2172
- def __init__ (self , size , num_classes ):
2172
+ def __init__ (self , size , num_classes , one_hot_labels ):
2173
2173
self .size = size
2174
2174
self .num_classes = num_classes
2175
+ self .one_hot_labels = one_hot_labels
2175
2176
assert size < num_classes
2176
2177
2177
2178
def __getitem__ (self , idx ):
2178
2179
img = torch .rand (3 , 100 , 100 )
2179
2180
label = idx # This ensures all labels in a batch are unique and makes testing easier
2181
+ if self .one_hot_labels :
2182
+ label = torch .nn .functional .one_hot (torch .tensor (label ), num_classes = self .num_classes )
2180
2183
return img , label
2181
2184
2182
2185
def __len__ (self ):
2183
2186
return self .size
2184
2187
2185
2188
@pytest .mark .parametrize ("T" , [transforms .CutMix , transforms .MixUp ])
2186
- def test_supported_input_structure (self , T ):
2189
+ @pytest .mark .parametrize ("one_hot_labels" , (True , False ))
2190
+ def test_supported_input_structure (self , T , one_hot_labels ):
2187
2191
2188
2192
batch_size = 32
2189
2193
num_classes = 100
2190
2194
2191
- dataset = self .DummyDataset (size = batch_size , num_classes = num_classes )
2195
+ dataset = self .DummyDataset (size = batch_size , num_classes = num_classes , one_hot_labels = one_hot_labels )
2192
2196
2193
2197
cutmix_mixup = T (num_classes = num_classes )
2194
2198
@@ -2198,7 +2202,7 @@ def test_supported_input_structure(self, T):
2198
2202
img , target = next (iter (dl ))
2199
2203
input_img_size = img .shape [- 3 :]
2200
2204
assert isinstance (img , torch .Tensor ) and isinstance (target , torch .Tensor )
2201
- assert target .shape == (batch_size ,)
2205
+ assert target .shape == ((batch_size , num_classes ) if one_hot_labels else (batch_size ,))  # parenthesized: otherwise the index-label branch asserts a truthy tuple and checks nothing
2202
2206
2203
2207
def check_output (img , target ):
2204
2208
assert img .shape == (batch_size , * input_img_size )
@@ -2209,7 +2213,7 @@ def check_output(img, target):
2209
2213
2210
2214
# After Dataloader, as unpacked input
2211
2215
img , target = next (iter (dl ))
2212
- assert target .shape == (batch_size ,)
2216
+ assert target .shape == ((batch_size , num_classes ) if one_hot_labels else (batch_size ,))  # parenthesized: otherwise the index-label branch asserts a truthy tuple and checks nothing
2213
2217
img , target = cutmix_mixup (img , target )
2214
2218
check_output (img , target )
2215
2219
@@ -2264,30 +2268,29 @@ def test_error(self, T):
2264
2268
with pytest .raises (ValueError , match = "Could not infer where the labels are" ):
2265
2269
cutmix_mixup ({"img" : imgs , "Nothing_else" : 3 })
2266
2270
2267
- with pytest .raises (ValueError , match = "labels tensor should be of shape " ):
2271
+ with pytest .raises (ValueError , match = "labels should be index based " ):
2268
2272
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
2269
2273
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
2270
2274
cutmix_mixup (imgs )
2271
2275
2272
2276
with pytest .raises (ValueError , match = "When using the default labels_getter" ):
2273
2277
cutmix_mixup (imgs , "not_a_tensor" )
2274
2278
2275
- with pytest .raises (ValueError , match = "labels tensor should be of shape" ):
2276
- cutmix_mixup (imgs , torch .randint (0 , 2 , size = (2 , 3 )))
2277
-
2278
2279
with pytest .raises (ValueError , match = "Expected a batched input with 4 dims" ):
2279
2280
cutmix_mixup (imgs [None , None ], torch .randint (0 , num_classes , size = (batch_size ,)))
2280
2281
2281
2282
with pytest .raises (ValueError , match = "does not match the batch size of the labels" ):
2282
2283
cutmix_mixup (imgs , torch .randint (0 , num_classes , size = (batch_size + 1 ,)))
2283
2284
2284
- with pytest .raises (ValueError , match = "labels tensor should be of shape" ):
2285
- # The purpose of this check is more about documenting the current
2286
- # behaviour of what happens on a Compose(), rather than actually
2287
- # asserting the expected behaviour. We may support Compose() in the
2288
- # future, e.g. for 2 consecutive CutMix?
2289
- labels = torch .randint (0 , num_classes , size = (batch_size ,))
2290
- transforms .Compose ([cutmix_mixup , cutmix_mixup ])(imgs , labels )
2285
+ with pytest .raises (ValueError , match = "When passing 2D labels" ):
2286
+ wrong_num_classes = num_classes + 1
2287
+ T (alpha = 0.5 , num_classes = num_classes )(imgs , torch .randint (0 , 2 , size = (batch_size , wrong_num_classes )))
2288
+
2289
+ with pytest .raises (ValueError , match = "but got a tensor of shape" ):
2290
+ cutmix_mixup (imgs , torch .randint (0 , 2 , size = (2 , 3 , 4 )))
2291
+
2292
+ with pytest .raises (ValueError , match = "num_classes must be passed" ):
2293
+ T (alpha = 0.5 )(imgs , torch .randint (0 , num_classes , size = (batch_size ,)))
2291
2294
2292
2295
2293
2296
@pytest .mark .parametrize ("key" , ("labels" , "LABELS" , "LaBeL" , "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT" ))
0 commit comments