From dc069039c1cd0e5afa4fed205577ff9bea2c3500 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Aug 2023 09:50:41 +0200 Subject: [PATCH 1/2] allow dispatch to PIL image subclasses --- test/test_transforms_v2_refactored.py | 33 +++++++++++++------ .../transforms/v2/functional/_utils.py | 23 +++++-------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index fa1ed05b84b..6ee5e979a7e 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -3,6 +3,7 @@ import inspect import math import re +from pathlib import Path from unittest import mock import numpy as np @@ -2126,16 +2127,10 @@ class TestGetKernel: datapoints.Video: F.resize_video, } - def test_unsupported_types(self): - class MyTensor(torch.Tensor): - pass - - class MyPILImage(PIL.Image.Image): - pass - - for input_type in [str, int, object, MyTensor, MyPILImage]: - with pytest.raises(TypeError, match="supports inputs of type"): - _get_kernel(F.resize, input_type) + @pytest.mark.parametrize("input_type", [str, int, object]) + def test_unsupported_types(self, input_type): + with pytest.raises(TypeError, match="supports inputs of type"): + _get_kernel(F.resize, input_type) def test_exact_match(self): # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the @@ -2197,6 +2192,24 @@ def resize_my_datapoint(): assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint + def test_pil_image_subclass(self): + opened_image = PIL.Image.open(Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg") + loaded_image = opened_image.convert("RGB") + + # check the assumptions + assert isinstance(opened_image, PIL.Image.Image) + assert type(opened_image) is not PIL.Image.Image + + assert type(loaded_image) is PIL.Image.Image + + size = [17, 11] + for image in [opened_image, loaded_image]: + kernel = _get_kernel(F.resize, type(image)) + + output = kernel(image, size=size) + + assert F.get_size(output) == size + class TestPermuteChannels: _DEFAULT_PERMUTATION = [2, 0, 1] diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 1f5c6f5eea0..cb4313f74ee 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False): if not registry: raise ValueError(f"No kernel registered for functional {functional.__name__}.") - # In case we have an exact type match, we take a shortcut. - if input_type in registry: - return registry[input_type] - - # In case of datapoints, we check if we have a kernel for a superclass registered - if issubclass(input_type, datapoints.Datapoint): - # Since we have already checked for an exact match above, we can start the traversal at the superclass. - for cls in input_type.__mro__[1:]: - if cls is datapoints.Datapoint: - # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the - # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't - # allow kernels to be registered for datapoints.Datapoint anyway. - break - elif cls in registry: - return registry[cls] + for cls in input_type.__mro__: + if cls in registry: + return registry[cls] + elif issubclass(input_type, datapoints.Datapoint) and cls is datapoints.Datapoint: + # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the + # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't + # allow kernels to be registered for datapoints.Datapoint anyway. + break if allow_passthrough: return lambda inpt, *args, **kwargs: inpt From c5a1a07369b2df1b6268ef0b2eb859f6a2997eb4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Aug 2023 11:00:37 +0200 Subject: [PATCH 2/2] simplify check --- torchvision/transforms/v2/functional/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index cb4313f74ee..dd1fc81fb83 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -103,7 +103,7 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False): for cls in input_type.__mro__: if cls in registry: return registry[cls] - elif issubclass(input_type, datapoints.Datapoint) and cls is datapoints.Datapoint: + elif cls is datapoints.Datapoint: # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't # allow kernels to be registered for datapoints.Datapoint anyway.