diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4f9a08b8412..03f9e906675 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -122,61 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker): t(inpt) -class TestToImage: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch( - "torchvision.transforms.v2.functional.to_image", - return_value=torch.rand(1, 3, 8, 8), - ) - - inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToImage() - transform(inpt) - if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt) - - -class TestToPILImage: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image") - - inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToPILImage() - transform(inpt) - if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt, mode=transform.mode) - - -class TestToTensor: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.transforms.functional.to_tensor") - - inpt = mocker.MagicMock(spec=inpt_type) - with pytest.warns(UserWarning, match="deprecated and will be removed"): - transform = transforms.ToTensor() - transform(inpt) - if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt) - - class TestContainers: @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]) def test_assertions(self, transform_cls): diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index b4ce189e758..397d42101ce 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -72,21 +72,6 @@ def __init__( LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2) CONSISTENCY_CONFIGS = [ - ConsistencyConfig( - v2_transforms.ToPILImage, - legacy_transforms.ToPILImage, - [NotScriptableArgsKwargs()], - make_images_kwargs=dict( - color_spaces=[ - "GRAY", - "GRAY_ALPHA", - "RGB", - "RGBA", - ], - extra_dims=[()], - ), - supports_pil=False, - ), ConsistencyConfig( v2_transforms.Lambda, legacy_transforms.Lambda, @@ -97,14 +82,6 @@ def __init__( # images given that the transform does nothing but call it anyway. supports_pil=False, ), - ConsistencyConfig( - v2_transforms.PILToTensor, - legacy_transforms.PILToTensor, - ), - ConsistencyConfig( - v2_transforms.ToTensor, - legacy_transforms.ToTensor, - ), ConsistencyConfig( v2_transforms.Compose, legacy_transforms.Compose, diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 079a8e619d9..b700b159ec5 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -5047,3 +5047,82 @@ def test_transform_error_cuda(self): ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector" ): transform(input) + + +def make_image_numpy(*args, **kwargs): + image = make_image_tensor(*args, **kwargs) + return image.permute((1, 2, 0)).numpy() + + +class TestToImage: + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy]) + @pytest.mark.parametrize("fn", [F.to_image, transform_cls_to_functional(transforms.ToImage)]) + def test_functional_and_transform(self, make_input, fn): + input = make_input() + output = fn(input) + + assert isinstance(output, tv_tensors.Image) + + input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input) + assert F.get_size(output) == input_size + + if isinstance(input, torch.Tensor): + assert output.data_ptr() == input.data_ptr() + + def test_functional_error(self): + with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"): + F.to_image(object()) + + +class TestToPILImage: + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_numpy]) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("fn", [F.to_pil_image, transform_cls_to_functional(transforms.ToPILImage)]) + def test_functional_and_transform(self, make_input, color_space, fn): + input = make_input(color_space=color_space) + output = fn(input) + + assert isinstance(output, PIL.Image.Image) + + input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input) + assert F.get_size(output) == input_size + + def test_functional_error(self): + with pytest.raises(TypeError, match="pic should be Tensor or ndarray"): + F.to_pil_image(object()) + + for ndim in [1, 4]: + with pytest.raises(ValueError, match="pic should be 2/3 dimensional"): + F.to_pil_image(torch.empty(*[1] * ndim)) + + with pytest.raises(ValueError, match="pic should not have > 4 channels"): + num_channels = 5 + F.to_pil_image(torch.empty(num_channels, 1, 1)) + + +class TestToTensor: + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy]) + def test_smoke(self, make_input): + with pytest.warns(UserWarning, match="deprecated and will be removed"): + transform = transforms.ToTensor() + + input = make_input() + output = transform(input) + + input_size = list(input.shape[:2]) if isinstance(input, np.ndarray) else F.get_size(input) + assert F.get_size(output) == input_size + + +class TestPILToTensor: + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("fn", [F.pil_to_tensor, transform_cls_to_functional(transforms.PILToTensor)]) + def test_functional_and_transform(self, color_space, fn): + input = make_image_pil(color_space=color_space) + output = fn(input) + + assert isinstance(output, torch.Tensor) and not isinstance(output, tv_tensors.TVTensor) + assert F.get_size(output) == F.get_size(input) + + def test_functional_error(self): + with pytest.raises(TypeError, match="pic should be PIL Image"): + F.pil_to_tensor(object()) diff --git a/torchvision/transforms/v2/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py index 062f85198ee..02aeda83df3 100644 --- a/torchvision/transforms/v2/functional/_type_conversion.py +++ b/torchvision/transforms/v2/functional/_type_conversion.py @@ -17,7 +17,9 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso elif isinstance(inpt, torch.Tensor): output = inpt else: - raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.") + raise TypeError( + f"Input can either be a pure Tensor, a numpy array, or a PIL image, but got {type(inpt)} instead." + ) return tv_tensors.Image(output)