Skip to content

Commit c585a51

Browse files
Enable one-hot-encoded labels in MixUp and CutMix (#8427)
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent 778ce48 commit c585a51

File tree

2 files changed

+40 -23 lines changed

test/test_transforms_v2.py

+19 -16
Original file line numberDiff line numberDiff line change
@@ -2169,26 +2169,30 @@ def test_image_correctness(self, brightness_factor):
21692169

21702170
class TestCutMixMixUp:
21712171
class DummyDataset:
2172-
def __init__(self, size, num_classes):
2172+
def __init__(self, size, num_classes, one_hot_labels):
21732173
self.size = size
21742174
self.num_classes = num_classes
2175+
self.one_hot_labels = one_hot_labels
21752176
assert size < num_classes
21762177

21772178
def __getitem__(self, idx):
21782179
img = torch.rand(3, 100, 100)
21792180
label = idx # This ensures all labels in a batch are unique and makes testing easier
2181+
if self.one_hot_labels:
2182+
label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
21802183
return img, label
21812184

21822185
def __len__(self):
21832186
return self.size
21842187

21852188
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
2186-
def test_supported_input_structure(self, T):
2189+
@pytest.mark.parametrize("one_hot_labels", (True, False))
2190+
def test_supported_input_structure(self, T, one_hot_labels):
21872191

21882192
batch_size = 32
21892193
num_classes = 100
21902194

2191-
dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)
2195+
dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)
21922196

21932197
cutmix_mixup = T(num_classes=num_classes)
21942198

@@ -2198,7 +2202,7 @@ def test_supported_input_structure(self, T):
21982202
img, target = next(iter(dl))
21992203
input_img_size = img.shape[-3:]
22002204
assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
2201-
assert target.shape == (batch_size,)
2205+
assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
22022206

22032207
def check_output(img, target):
22042208
assert img.shape == (batch_size, *input_img_size)
@@ -2209,7 +2213,7 @@ def check_output(img, target):
22092213

22102214
# After Dataloader, as unpacked input
22112215
img, target = next(iter(dl))
2212-
assert target.shape == (batch_size,)
2216+
assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
22132217
img, target = cutmix_mixup(img, target)
22142218
check_output(img, target)
22152219

@@ -2264,30 +2268,29 @@ def test_error(self, T):
22642268
with pytest.raises(ValueError, match="Could not infer where the labels are"):
22652269
cutmix_mixup({"img": imgs, "Nothing_else": 3})
22662270

2267-
with pytest.raises(ValueError, match="labels tensor should be of shape"):
2271+
with pytest.raises(ValueError, match="labels should be index based"):
22682272
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
22692273
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
22702274
cutmix_mixup(imgs)
22712275

22722276
with pytest.raises(ValueError, match="When using the default labels_getter"):
22732277
cutmix_mixup(imgs, "not_a_tensor")
22742278

2275-
with pytest.raises(ValueError, match="labels tensor should be of shape"):
2276-
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))
2277-
22782279
with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
22792280
cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))
22802281

22812282
with pytest.raises(ValueError, match="does not match the batch size of the labels"):
22822283
cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))
22832284

2284-
with pytest.raises(ValueError, match="labels tensor should be of shape"):
2285-
# The purpose of this check is more about documenting the current
2286-
# behaviour of what happens on a Compose(), rather than actually
2287-
# asserting the expected behaviour. We may support Compose() in the
2288-
# future, e.g. for 2 consecutive CutMix?
2289-
labels = torch.randint(0, num_classes, size=(batch_size,))
2290-
transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
2285+
with pytest.raises(ValueError, match="When passing 2D labels"):
2286+
wrong_num_classes = num_classes + 1
2287+
T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))
2288+
2289+
with pytest.raises(ValueError, match="but got a tensor of shape"):
2290+
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))
2291+
2292+
with pytest.raises(ValueError, match="num_classes must be passed"):
2293+
T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))
22912294

22922295

22932296
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))

torchvision/transforms/v2/_augment.py

+21 -7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import numbers
33
import warnings
4-
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
55

66
import PIL.Image
77
import torch
@@ -142,7 +142,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
142142

143143

144144
class _BaseMixUpCutMix(Transform):
145-
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
145+
def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
146146
super().__init__()
147147
self.alpha = float(alpha)
148148
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
@@ -162,10 +162,21 @@ def forward(self, *inputs):
162162
labels = self._labels_getter(inputs)
163163
if not isinstance(labels, torch.Tensor):
164164
raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
165-
elif labels.ndim != 1:
165+
if labels.ndim not in (1, 2):
166166
raise ValueError(
167-
f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
167+
f"labels should be index based with shape (batch_size,) "
168+
f"or probability based with shape (batch_size, num_classes), "
169+
f"but got a tensor of shape {labels.shape} instead."
168170
)
171+
if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
172+
raise ValueError(
173+
f"When passing 2D labels, "
174+
f"the number of elements in last dimension must match num_classes: "
175+
f"{labels.shape[-1]} != {self.num_classes}. "
176+
f"You can Leave num_classes to None."
177+
)
178+
if labels.ndim == 1 and self.num_classes is None:
179+
raise ValueError("num_classes must be passed if the labels are index-based (1D)")
169180

170181
params = {
171182
"labels": labels,
@@ -198,7 +209,8 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
198209
)
199210

200211
def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
201-
label = one_hot(label, num_classes=self.num_classes)
212+
if label.ndim == 1:
213+
label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type]
202214
if not label.dtype.is_floating_point:
203215
label = label.float()
204216
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
@@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
223235
224236
Args:
225237
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
226-
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
238+
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
239+
Can be None only if the labels are already one-hot-encoded.
227240
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
228241
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
229242
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
@@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
271284
272285
Args:
273286
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
274-
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
287+
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
288+
Can be None only if the labels are already one-hot-encoded.
275289
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
276290
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
277291
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.

0 commit comments

Comments
 (0)