From d4f16cb3fe5093f22ec834a7998bd49800084e1c Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Thu, 2 Nov 2023 12:02:00 -0500 Subject: [PATCH 1/8] Added torch compile checks for functional ops --- test/test_transforms_v2.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4f8d0027bd6..06f10fcca71 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -186,7 +186,21 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs): functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs): +def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): + """Checks if the functional can be torch compiled and the compiled version can be called without error.""" + if not isinstance(input, tv_tensors.Image): + return + + functional_compiled = torch.compile(functional) + functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs) + + explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs) + # TODO: Set expected values to 2, 1 once fixed the graph break related to function registration + assert explanation.graph_count == 2 + assert explanation.graph_break_count == 1 + + +def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs): unknown_input = object() with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): functional(unknown_input, *args, **kwargs) @@ -204,6 +218,9 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar if check_scripted_smoke: _check_functional_scripted_smoke(functional, input, *args, **kwargs) + if check_torch_compile_smoke: + _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) + def check_functional_kernel_signature_match(functional, *, kernel, input_type): """Checks if the signature of the functional matches the kernel signature.""" @@ -656,6 +673,7 @@ def test_functional(self, size, make_input): size=size, antialias=True, check_scripted_smoke=not isinstance(size, int), + check_torch_compile_smoke=False, ) @pytest.mark.parametrize( @@ -3469,7 +3487,12 @@ def test_kernel(self, kernel, make_input): ) def test_functional(self, make_input): check_functional( - F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True + F.resized_crop, + make_input(self.INPUT_SIZE), + **self.CROP_KWARGS, + size=self.OUTPUT_SIZE, + antialias=True, + check_torch_compile_smoke=False, ) @pytest.mark.parametrize( @@ -3949,7 +3972,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS) + check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS, check_torch_compile_smoke=False) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -4106,7 +4129,7 @@ def test_kernel_video(self): @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) def test_functional(self, make_input): - check_functional(F.equalize, make_input()) + check_functional(F.equalize, make_input(), check_torch_compile_smoke=False) @pytest.mark.parametrize( ("kernel", "input_type"), From d5ca7b297ee5afc881a341f4a57a52f5ff14dd21 Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Fri, 3 Nov 2023 12:18:13 -0500 Subject: [PATCH 2/8] Fixed failing tests --- test/test_transforms_v2.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 06f10fcca71..afafeb65e71 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -7,6 +7,7 @@ import pickle import random import re +import sys from copy import deepcopy from pathlib import Path from unittest import mock @@ -195,9 +196,9 @@ def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs) explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs) - # TODO: Set expected values to 2, 1 once fixed the graph break related to function registration - assert explanation.graph_count == 2 - assert explanation.graph_break_count == 1 + # TODO: Set expected values to 1, 0 once fixed the graph break related to function registration + assert explanation.graph_count in (1, 2) + assert explanation.graph_break_count in (0, 1) def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs): @@ -218,7 +219,8 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_ if check_scripted_smoke: _check_functional_scripted_smoke(functional, input, *args, **kwargs) - if check_torch_compile_smoke: + # Skip check on Windows as torch.compile does not work on Win32 + if check_torch_compile_smoke and sys.platform != "win32": _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) @@ -1982,6 +1984,7 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): make_input(dtype=input_dtype, device=device), dtype=output_dtype, scale=scale, + ) @pytest.mark.parametrize( From 6fce5da7da61925f8847b275bf6e4b229d0d9333 Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Fri, 3 Nov 2023 12:24:26 -0500 Subject: [PATCH 3/8] Fixed lint --- test/test_transforms_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index afafeb65e71..dac560d6b41 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1984,7 +1984,6 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): make_input(dtype=input_dtype, device=device), dtype=output_dtype, scale=scale, - ) @pytest.mark.parametrize( From 829a2d3da334185026a000e176db20a9552f4f45 Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Mon, 6 Nov 2023 07:10:33 -0600 Subject: [PATCH 4/8] Fixed depr warning problem --- test/test_transforms_v2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index dac560d6b41..5ed2fe76ce0 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -221,7 +221,13 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_ # Skip check on Windows as torch.compile does not work on Win32 if check_torch_compile_smoke and sys.platform != "win32": - _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) + # Temporary fix to catch deprectation warning + # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged: + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) def check_functional_kernel_signature_match(functional, *, kernel, input_type): From 7f91053e219946189acb38c987e36c1bdfcb607c Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Tue, 7 Nov 2023 04:38:22 -0600 Subject: [PATCH 5/8] Removed temporary fixes --- test/test_transforms_v2.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 5ed2fe76ce0..38136152f39 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -221,13 +221,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_ # Skip check on Windows as torch.compile does not work on Win32 if check_torch_compile_smoke and sys.platform != "win32": - # Temporary fix to catch deprectation warning - # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged: - import warnings - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) + _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) def check_functional_kernel_signature_match(functional, *, kernel, input_type): @@ -4137,7 +4131,7 @@ def test_kernel_video(self): @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) def test_functional(self, make_input): - check_functional(F.equalize, make_input(), check_torch_compile_smoke=False) + check_functional(F.equalize, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), From 84100215138a7d78a5396bec1e688bd27f7e53f7 Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Wed, 8 Nov 2023 02:59:51 -0600 Subject: [PATCH 6/8] Enabled torch compile functional tests for boxes, masks and video Annotated exceptions with encountered errors --- test/test_transforms_v2.py | 110 ++++++++++++++++++++++++++++++++++--- 1 file changed, 102 insertions(+), 8 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 38136152f39..4e3c6bf49dc 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -189,7 +189,7 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs): def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): """Checks if the functional can be torch compiled and the compiled version can be called without error.""" - if not isinstance(input, tv_tensors.Image): + if not isinstance(input, torch.Tensor): return functional_compiled = torch.compile(functional) @@ -1162,7 +1162,22 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) + # TODO: Remove this when fixed + # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py:1573: in UNPACK_SEQUENCE + # assert len(val) == inst.argval + # E AssertionError: + # E + # E from user code: + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 392, in resume_in_affine + # E return kernel( + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 692, in affine_image + # E return _apply_grid_transform(image, grid, interpolation.value, fill=fill) + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform + # E num_channels, input_height, input_width = input_shape[-3:] + check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True + check_functional( + F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke + ) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1586,7 +1601,12 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) + # TODO: Remove this when fixed + # Error is the same as for TestAffine + check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True + check_functional( + F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke + ) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2651,8 +2671,38 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): + # TODO: Remove this when fixed + # - TestElastic.test_functional[make_bounding_boxes]: + # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:410: in _fn + # return fn(*args, **kwargs) + # torchvision/transforms/v2/functional/_geometry.py:1705: in elastic + # kernel = _get_kernel(elastic, type(inpt)) + # torchvision/transforms/v2/functional/_geometry.py:1706: in resume_in_elastic + # return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + # torchvision/transforms/v2/functional/_geometry.py:1741: in elastic_image + # raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + # torchvision/transforms/v2/functional/_geometry.py:1741: in resume_in_elastic_image + # raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + # E ValueError: Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2]) + # + # - TestElastic.test_functional[make_segmentation_mask]: + # E AssertionError: + # E + # E from user code: + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1706, in resume_in_elastic + # E return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1746, in elastic_image + # E output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform + # E num_channels, input_height, input_width = input_shape[-3:] + check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True input = make_input() - check_functional(F.elastic, input, displacement=self._make_displacement(input)) + check_functional( + F.elastic, + input, + displacement=self._make_displacement(input), + check_torch_compile_smoke=check_torch_compile_smoke, + ) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2765,7 +2815,23 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS) + # TODO: Remove this when fixed + # E AssertionError: + # E + # E from user code: + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1324, in resume_in_crop + # E return kernel(inpt, top=top, left=left, height=height, width=width) + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1343, in crop_image + # E return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill + # E num_channels, height, width = shape[-3:] + check_torch_compile_smoke = False if make_input == make_bounding_boxes else True + check_functional( + F.crop, + make_input(self.INPUT_SIZE), + **self.MINIMAL_CROP_KWARGS, + check_torch_compile_smoke=check_torch_compile_smoke, + ) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -3402,7 +3468,18 @@ def test_kernel_inplace(self, old_format, new_format): @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) def test_functional(self, old_format, new_format): - check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format) + # TODO: Disabled torch.compile check due to the error: + # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format + # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") + # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format + # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") + # E ValueError: For pure tensor inputs, `old_format` has to be passed. + check_functional( + F.convert_bounding_box_format, + make_bounding_boxes(format=old_format), + new_format=new_format, + check_torch_compile_smoke=False, + ) @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) @pytest.mark.parametrize("format_type", ["enum", "str"]) @@ -3676,7 +3753,18 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - check_functional(F.pad, make_input(), padding=[1]) + # TODO: Remove this when fixed + # E AssertionError: + # E + # E from user code: + # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1104, in resume_in_pad + # E return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode) + # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1154, in pad_image + # E return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) + # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill + # E num_channels, height, width = shape[-3:] + check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True + check_functional(F.pad, make_input(), padding=[1], check_torch_compile_smoke=check_torch_compile_smoke) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -4327,7 +4415,13 @@ def test_kernel(self, format, dtype, device): @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional(self, format): - check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format)) + # TODO: Disabled torch.compile check due to the error: + # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format + # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") + # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format + # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") + # E ValueError: For pure tensor inputs, `old_format` has to be passed. + check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format), check_torch_compile_smoke=False) def test_errors(self): input_tv_tensor = make_bounding_boxes() From d18565643598dfeb419e96d009052e17e21b790e Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Wed, 8 Nov 2023 04:21:33 -0600 Subject: [PATCH 7/8] Fixed failing compiled execs --- test/test_transforms_v2.py | 115 +++-------------------------- torchvision/tv_tensors/__init__.py | 4 + 2 files changed, 14 insertions(+), 105 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4e3c6bf49dc..7f19b923559 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -193,11 +193,10 @@ def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): return functional_compiled = torch.compile(functional) - functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs) + functional_compiled(input, *args, **kwargs) - explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs) - # TODO: Set expected values to 1, 0 once fixed the graph break related to function registration - assert explanation.graph_count in (1, 2) + explanation = torch._dynamo.explain(functional_compiled)(input, *args, **kwargs) + # TODO: Set expected value to 0 once fixed the graph break related to function registration assert explanation.graph_break_count in (0, 1) @@ -1162,22 +1161,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py:1573: in UNPACK_SEQUENCE - # assert len(val) == inst.argval - # E AssertionError: - # E - # E from user code: - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 392, in resume_in_affine - # E return kernel( - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 692, in affine_image - # E return _apply_grid_transform(image, grid, interpolation.value, fill=fill) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform - # E num_channels, input_height, input_width = input_shape[-3:] - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True - check_functional( - F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke - ) + check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1601,12 +1585,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # Error is the same as for TestAffine - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True - check_functional( - F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke - ) + check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2671,38 +2650,8 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # - TestElastic.test_functional[make_bounding_boxes]: - # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:410: in _fn - # return fn(*args, **kwargs) - # torchvision/transforms/v2/functional/_geometry.py:1705: in elastic - # kernel = _get_kernel(elastic, type(inpt)) - # torchvision/transforms/v2/functional/_geometry.py:1706: in resume_in_elastic - # return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) - # torchvision/transforms/v2/functional/_geometry.py:1741: in elastic_image - # raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") - # torchvision/transforms/v2/functional/_geometry.py:1741: in resume_in_elastic_image - # raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") - # E ValueError: Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2]) - # - # - TestElastic.test_functional[make_segmentation_mask]: - # E AssertionError: - # E - # E from user code: - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1706, in resume_in_elastic - # E return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1746, in elastic_image - # E output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform - # E num_channels, input_height, input_width = input_shape[-3:] - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True input = make_input() - check_functional( - F.elastic, - input, - displacement=self._make_displacement(input), - check_torch_compile_smoke=check_torch_compile_smoke, - ) + check_functional(F.elastic, input, displacement=self._make_displacement(input)) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2815,23 +2764,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # E AssertionError: - # E - # E from user code: - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1324, in resume_in_crop - # E return kernel(inpt, top=top, left=left, height=height, width=width) - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1343, in crop_image - # E return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") - # E File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill - # E num_channels, height, width = shape[-3:] - check_torch_compile_smoke = False if make_input == make_bounding_boxes else True - check_functional( - F.crop, - make_input(self.INPUT_SIZE), - **self.MINIMAL_CROP_KWARGS, - check_torch_compile_smoke=check_torch_compile_smoke, - ) + check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -3468,18 +3401,7 @@ def test_kernel_inplace(self, old_format, new_format): @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) def test_functional(self, old_format, new_format): - # TODO: Disabled torch.compile check due to the error: - # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # E ValueError: For pure tensor inputs, `old_format` has to be passed. - check_functional( - F.convert_bounding_box_format, - make_bounding_boxes(format=old_format), - new_format=new_format, - check_torch_compile_smoke=False, - ) + check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format) @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) @pytest.mark.parametrize("format_type", ["enum", "str"]) @@ -3753,18 +3675,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - # TODO: Remove this when fixed - # E AssertionError: - # E - # E from user code: - # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1104, in resume_in_pad - # E return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1154, in pad_image - # E return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) - # E File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill - # E num_channels, height, width = shape[-3:] - check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True - check_functional(F.pad, make_input(), padding=[1], check_torch_compile_smoke=check_torch_compile_smoke) + check_functional(F.pad, make_input(), padding=[1]) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -4415,13 +4326,7 @@ def test_kernel(self, format, dtype, device): @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional(self, format): - # TODO: Disabled torch.compile check due to the error: - # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format - # raise ValueError("For pure tensor inputs, `old_format` has to be passed.") - # E ValueError: For pure tensor inputs, `old_format` has to be passed. - check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format), check_torch_compile_smoke=False) + check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format)) def test_errors(self): input_tv_tensor = make_bounding_boxes() diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index d55e10e8620..2e58d9d4c6a 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -8,6 +8,10 @@ from ._video import Video +# TODO: Fix this. We skip this method as it leads to +# RecursionError: maximum recursion depth exceeded while calling a Python object +# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph +@torch.compiler.disable def wrap(wrappee, *, like, **kwargs): """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``. From 3c959d904bb41e8bad059d3c11b1260b7b576a48 Mon Sep 17 00:00:00 2001 From: vfdev-5 <vfdev.5@gmail.com> Date: Fri, 10 Nov 2023 04:41:50 -0600 Subject: [PATCH 8/8] Revert "Removed temporary fixes" This reverts commit 7f91053e219946189acb38c987e36c1bdfcb607c. --- test/test_transforms_v2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7f19b923559..9cba49440d1 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -220,7 +220,13 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_ # Skip check on Windows as torch.compile does not work on Win32 if check_torch_compile_smoke and sys.platform != "win32": - _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) + # Temporary fix to catch deprectation warning + # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged: + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + _check_functional_torch_compile_smoke(functional, input, *args, **kwargs) def check_functional_kernel_signature_match(functional, *, kernel, input_type): @@ -4130,7 +4136,7 @@ def test_kernel_video(self): @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) def test_functional(self, make_input): - check_functional(F.equalize, make_input()) + check_functional(F.equalize, make_input(), check_torch_compile_smoke=False) @pytest.mark.parametrize( ("kernel", "input_type"),