Skip to content

Commit 48f8473

Browse files
authored
port tests for type conversion transforms (#8003)
1 parent ee28bb3 commit 48f8473

File tree

4 files changed

+82
-79
lines changed

4 files changed

+82
-79
lines changed

test/test_transforms_v2.py

-55
Original file line numberDiff line numberDiff line change
@@ -122,61 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
122122
t(inpt)
123123

124124

125-
class TestToImage:
126-
@pytest.mark.parametrize(
127-
"inpt_type",
128-
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
129-
)
130-
def test__transform(self, inpt_type, mocker):
131-
fn = mocker.patch(
132-
"torchvision.transforms.v2.functional.to_image",
133-
return_value=torch.rand(1, 3, 8, 8),
134-
)
135-
136-
inpt = mocker.MagicMock(spec=inpt_type)
137-
transform = transforms.ToImage()
138-
transform(inpt)
139-
if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
140-
assert fn.call_count == 0
141-
else:
142-
fn.assert_called_once_with(inpt)
143-
144-
145-
class TestToPILImage:
146-
@pytest.mark.parametrize(
147-
"inpt_type",
148-
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
149-
)
150-
def test__transform(self, inpt_type, mocker):
151-
fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
152-
153-
inpt = mocker.MagicMock(spec=inpt_type)
154-
transform = transforms.ToPILImage()
155-
transform(inpt)
156-
if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
157-
assert fn.call_count == 0
158-
else:
159-
fn.assert_called_once_with(inpt, mode=transform.mode)
160-
161-
162-
class TestToTensor:
163-
@pytest.mark.parametrize(
164-
"inpt_type",
165-
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
166-
)
167-
def test__transform(self, inpt_type, mocker):
168-
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
169-
170-
inpt = mocker.MagicMock(spec=inpt_type)
171-
with pytest.warns(UserWarning, match="deprecated and will be removed"):
172-
transform = transforms.ToTensor()
173-
transform(inpt)
174-
if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
175-
assert fn.call_count == 0
176-
else:
177-
fn.assert_called_once_with(inpt)
178-
179-
180125
class TestContainers:
181126
@pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
182127
def test_assertions(self, transform_cls):

test/test_transforms_v2_consistency.py

-23
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,6 @@ def __init__(
7272
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
7373

7474
CONSISTENCY_CONFIGS = [
75-
ConsistencyConfig(
76-
v2_transforms.ToPILImage,
77-
legacy_transforms.ToPILImage,
78-
[NotScriptableArgsKwargs()],
79-
make_images_kwargs=dict(
80-
color_spaces=[
81-
"GRAY",
82-
"GRAY_ALPHA",
83-
"RGB",
84-
"RGBA",
85-
],
86-
extra_dims=[()],
87-
),
88-
supports_pil=False,
89-
),
9075
ConsistencyConfig(
9176
v2_transforms.Lambda,
9277
legacy_transforms.Lambda,
@@ -97,14 +82,6 @@ def __init__(
9782
# images given that the transform does nothing but call it anyway.
9883
supports_pil=False,
9984
),
100-
ConsistencyConfig(
101-
v2_transforms.PILToTensor,
102-
legacy_transforms.PILToTensor,
103-
),
104-
ConsistencyConfig(
105-
v2_transforms.ToTensor,
106-
legacy_transforms.ToTensor,
107-
),
10885
ConsistencyConfig(
10986
v2_transforms.Compose,
11087
legacy_transforms.Compose,

test/test_transforms_v2_refactored.py

+79
Original file line numberDiff line numberDiff line change
@@ -5047,3 +5047,82 @@ def test_transform_error_cuda(self):
50475047
ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector"
50485048
):
50495049
transform(input)
5050+
5051+
5052+
def make_image_numpy(*args, **kwargs):
5053+
image = make_image_tensor(*args, **kwargs)
5054+
return image.permute((1, 2, 0)).numpy()
5055+
5056+
5057+
class TestToImage:
5058+
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
5059+
@pytest.mark.parametrize("fn", [F.to_image, transform_cls_to_functional(transforms.ToImage)])
5060+
def test_functional_and_transform(self, make_input, fn):
5061+
input = make_input()
5062+
output = fn(input)
5063+
5064+
assert isinstance(output, tv_tensors.Image)
5065+
5066+
input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
5067+
assert F.get_size(output) == input_size
5068+
5069+
if isinstance(input, torch.Tensor):
5070+
assert output.data_ptr() == input.data_ptr()
5071+
5072+
def test_functional_error(self):
5073+
with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"):
5074+
F.to_image(object())
5075+
5076+
5077+
class TestToPILImage:
5078+
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_numpy])
5079+
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5080+
@pytest.mark.parametrize("fn", [F.to_pil_image, transform_cls_to_functional(transforms.ToPILImage)])
5081+
def test_functional_and_transform(self, make_input, color_space, fn):
5082+
input = make_input(color_space=color_space)
5083+
output = fn(input)
5084+
5085+
assert isinstance(output, PIL.Image.Image)
5086+
5087+
input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
5088+
assert F.get_size(output) == input_size
5089+
5090+
def test_functional_error(self):
5091+
with pytest.raises(TypeError, match="pic should be Tensor or ndarray"):
5092+
F.to_pil_image(object())
5093+
5094+
for ndim in [1, 4]:
5095+
with pytest.raises(ValueError, match="pic should be 2/3 dimensional"):
5096+
F.to_pil_image(torch.empty(*[1] * ndim))
5097+
5098+
with pytest.raises(ValueError, match="pic should not have > 4 channels"):
5099+
num_channels = 5
5100+
F.to_pil_image(torch.empty(num_channels, 1, 1))
5101+
5102+
5103+
class TestToTensor:
5104+
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
5105+
def test_smoke(self, make_input):
5106+
with pytest.warns(UserWarning, match="deprecated and will be removed"):
5107+
transform = transforms.ToTensor()
5108+
5109+
input = make_input()
5110+
output = transform(input)
5111+
5112+
input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input)
5113+
assert F.get_size(output) == input_size
5114+
5115+
5116+
class TestPILToTensor:
5117+
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5118+
@pytest.mark.parametrize("fn", [F.pil_to_tensor, transform_cls_to_functional(transforms.PILToTensor)])
5119+
def test_functional_and_transform(self, color_space, fn):
5120+
input = make_image_pil(color_space=color_space)
5121+
output = fn(input)
5122+
5123+
assert isinstance(output, torch.Tensor) and not isinstance(output, tv_tensors.TVTensor)
5124+
assert F.get_size(output) == F.get_size(input)
5125+
5126+
def test_functional_error(self):
5127+
with pytest.raises(TypeError, match="pic should be PIL Image"):
5128+
F.pil_to_tensor(object())

torchvision/transforms/v2/functional/_type_conversion.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso
1717
elif isinstance(inpt, torch.Tensor):
1818
output = inpt
1919
else:
20-
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
20+
raise TypeError(
21+
f"Input can either be a pure Tensor, a numpy array, or a PIL image, but got {type(inpt)} instead."
22+
)
2123
return tv_tensors.Image(output)
2224

2325

0 commit comments

Comments
 (0)