From d4f16cb3fe5093f22ec834a7998bd49800084e1c Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Thu, 2 Nov 2023 12:02:00 -0500
Subject: [PATCH 1/8] Added torch compile checks for functional ops

---
 test/test_transforms_v2.py | 31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 4f8d0027bd6..06f10fcca71 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -186,7 +186,21 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
         functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
 
 
-def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
+def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs):
+    """Checks if the functional can be torch compiled and the compiled version can be called without error."""
+    if not isinstance(input, tv_tensors.Image):
+        return
+
+    functional_compiled = torch.compile(functional)
+    functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs)
+
+    explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs)
+    # TODO: Set expected values to 2, 1 once fixed the graph break related to function registration
+    assert explanation.graph_count == 2
+    assert explanation.graph_break_count == 1
+
+
+def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs):
     unknown_input = object()
     with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
         functional(unknown_input, *args, **kwargs)
@@ -204,6 +218,9 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar
     if check_scripted_smoke:
         _check_functional_scripted_smoke(functional, input, *args, **kwargs)
 
+    if check_torch_compile_smoke:
+        _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
+
 
 def check_functional_kernel_signature_match(functional, *, kernel, input_type):
     """Checks if the signature of the functional matches the kernel signature."""
@@ -656,6 +673,7 @@ def test_functional(self, size, make_input):
             size=size,
             antialias=True,
             check_scripted_smoke=not isinstance(size, int),
+            check_torch_compile_smoke=False,
         )
 
     @pytest.mark.parametrize(
@@ -3469,7 +3487,12 @@ def test_kernel(self, kernel, make_input):
     )
     def test_functional(self, make_input):
         check_functional(
-            F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True
+            F.resized_crop,
+            make_input(self.INPUT_SIZE),
+            **self.CROP_KWARGS,
+            size=self.OUTPUT_SIZE,
+            antialias=True,
+            check_torch_compile_smoke=False,
         )
 
     @pytest.mark.parametrize(
@@ -3949,7 +3972,7 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS)
+        check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS, check_torch_compile_smoke=False)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -4106,7 +4129,7 @@ def test_kernel_video(self):
 
     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
     def test_functional(self, make_input):
-        check_functional(F.equalize, make_input())
+        check_functional(F.equalize, make_input(), check_torch_compile_smoke=False)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),

From d5ca7b297ee5afc881a341f4a57a52f5ff14dd21 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Fri, 3 Nov 2023 12:18:13 -0500
Subject: [PATCH 2/8] Fixed failing tests

---
 test/test_transforms_v2.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 06f10fcca71..afafeb65e71 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -7,6 +7,7 @@
 import pickle
 import random
 import re
+import sys
 from copy import deepcopy
 from pathlib import Path
 from unittest import mock
@@ -195,9 +196,9 @@ def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs):
     functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs)
 
     explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs)
-    # TODO: Set expected values to 2, 1 once fixed the graph break related to function registration
-    assert explanation.graph_count == 2
-    assert explanation.graph_break_count == 1
+    # TODO: Set expected values to 1, 0 once fixed the graph break related to function registration
+    assert explanation.graph_count in (1, 2)
+    assert explanation.graph_break_count in (0, 1)
 
 
 def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs):
@@ -218,7 +219,8 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_
     if check_scripted_smoke:
         _check_functional_scripted_smoke(functional, input, *args, **kwargs)
 
-    if check_torch_compile_smoke:
+    # Skip check on Windows as torch.compile does not work on Win32
+    if check_torch_compile_smoke and sys.platform != "win32":
         _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
 
 
@@ -1982,6 +1984,7 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
             make_input(dtype=input_dtype, device=device),
             dtype=output_dtype,
             scale=scale,
+
         )
 
     @pytest.mark.parametrize(

From 6fce5da7da61925f8847b275bf6e4b229d0d9333 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Fri, 3 Nov 2023 12:24:26 -0500
Subject: [PATCH 3/8] Fixed lint

---
 test/test_transforms_v2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index afafeb65e71..dac560d6b41 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -1984,7 +1984,6 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
             make_input(dtype=input_dtype, device=device),
             dtype=output_dtype,
             scale=scale,
-
         )
 
     @pytest.mark.parametrize(

From 829a2d3da334185026a000e176db20a9552f4f45 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Mon, 6 Nov 2023 07:10:33 -0600
Subject: [PATCH 4/8] Fixed depr warning problem

---
 test/test_transforms_v2.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index dac560d6b41..5ed2fe76ce0 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -221,7 +221,13 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_
 
     # Skip check on Windows as torch.compile does not work on Win32
     if check_torch_compile_smoke and sys.platform != "win32":
-        _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
+        # Temporary fix to catch deprectation warning
+        # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged:
+        import warnings
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=DeprecationWarning)
+            _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
 
 
 def check_functional_kernel_signature_match(functional, *, kernel, input_type):

From 7f91053e219946189acb38c987e36c1bdfcb607c Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Tue, 7 Nov 2023 04:38:22 -0600
Subject: [PATCH 5/8] Removed temporary fixes

---
 test/test_transforms_v2.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 5ed2fe76ce0..38136152f39 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -221,13 +221,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_
 
     # Skip check on Windows as torch.compile does not work on Win32
     if check_torch_compile_smoke and sys.platform != "win32":
-        # Temporary fix to catch deprectation warning
-        # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged:
-        import warnings
-
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=DeprecationWarning)
-            _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
+        _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
 
 
 def check_functional_kernel_signature_match(functional, *, kernel, input_type):
@@ -4137,7 +4131,7 @@ def test_kernel_video(self):
 
     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
     def test_functional(self, make_input):
-        check_functional(F.equalize, make_input(), check_torch_compile_smoke=False)
+        check_functional(F.equalize, make_input())
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),

From 84100215138a7d78a5396bec1e688bd27f7e53f7 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Wed, 8 Nov 2023 02:59:51 -0600
Subject: [PATCH 6/8] Enabled torch compile functional tests for boxes, masks
 and video Annotated exceptions with encountered errors

---
 test/test_transforms_v2.py | 110 ++++++++++++++++++++++++++++++++++---
 1 file changed, 102 insertions(+), 8 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 38136152f39..4e3c6bf49dc 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -189,7 +189,7 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
 
 def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs):
     """Checks if the functional can be torch compiled and the compiled version can be called without error."""
-    if not isinstance(input, tv_tensors.Image):
+    if not isinstance(input, torch.Tensor):
         return
 
     functional_compiled = torch.compile(functional)
@@ -1162,7 +1162,22 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
+        # TODO: Remove this when fixed
+        # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py:1573: in UNPACK_SEQUENCE
+        #     assert len(val) == inst.argval
+        # E   AssertionError:
+        # E
+        # E   from user code:
+        # E      File "vision/torchvision/transforms/v2/functional/_geometry.py", line 392, in resume_in_affine
+        # E       return kernel(
+        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 692, in affine_image
+        # E       return _apply_grid_transform(image, grid, interpolation.value, fill=fill)
+        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform
+        # E       num_channels, input_height, input_width = input_shape[-3:]
+        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
+        check_functional(
+            F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke
+        )
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1586,7 +1601,12 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
+        # TODO: Remove this when fixed
+        # Error is the same as for TestAffine
+        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
+        check_functional(
+            F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke
+        )
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -2651,8 +2671,38 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
+        # TODO: Remove this when fixed
+        # - TestElastic.test_functional[make_bounding_boxes]:
+        # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:410: in _fn
+        #     return fn(*args, **kwargs)
+        # torchvision/transforms/v2/functional/_geometry.py:1705: in elastic
+        #     kernel = _get_kernel(elastic, type(inpt))
+        # torchvision/transforms/v2/functional/_geometry.py:1706: in resume_in_elastic
+        #     return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
+        # torchvision/transforms/v2/functional/_geometry.py:1741: in elastic_image
+        #     raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
+        # torchvision/transforms/v2/functional/_geometry.py:1741: in resume_in_elastic_image
+        #     raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
+        # E   ValueError: Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2])
+        #
+        # - TestElastic.test_functional[make_segmentation_mask]:
+        # E   AssertionError:
+        # E
+        # E   from user code:
+        # E      File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1706, in resume_in_elastic
+        # E       return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
+        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1746, in elastic_image
+        # E       output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
+        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform
+        # E       num_channels, input_height, input_width = input_shape[-3:]
+        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
         input = make_input()
-        check_functional(F.elastic, input, displacement=self._make_displacement(input))
+        check_functional(
+            F.elastic,
+            input,
+            displacement=self._make_displacement(input),
+            check_torch_compile_smoke=check_torch_compile_smoke,
+        )
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -2765,7 +2815,23 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS)
+        # TODO: Remove this when fixed
+        # E   AssertionError:
+        # E
+        # E   from user code:
+        # E      File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1324, in resume_in_crop
+        # E       return kernel(inpt, top=top, left=left, height=height, width=width)
+        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1343, in crop_image
+        # E       return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
+        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill
+        # E       num_channels, height, width = shape[-3:]
+        check_torch_compile_smoke = False if make_input == make_bounding_boxes else True
+        check_functional(
+            F.crop,
+            make_input(self.INPUT_SIZE),
+            **self.MINIMAL_CROP_KWARGS,
+            check_torch_compile_smoke=check_torch_compile_smoke,
+        )
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -3402,7 +3468,18 @@ def test_kernel_inplace(self, old_format, new_format):
 
     @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
     def test_functional(self, old_format, new_format):
-        check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format)
+        # TODO: Disabled torch.compile check due to the error:
+        # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format
+        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
+        # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format
+        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
+        # E   ValueError: For pure tensor inputs, `old_format` has to be passed.
+        check_functional(
+            F.convert_bounding_box_format,
+            make_bounding_boxes(format=old_format),
+            new_format=new_format,
+            check_torch_compile_smoke=False,
+        )
 
     @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
     @pytest.mark.parametrize("format_type", ["enum", "str"])
@@ -3676,7 +3753,18 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        check_functional(F.pad, make_input(), padding=[1])
+        # TODO: Remove this when fixed
+        # E   AssertionError:
+        # E
+        # E   from user code:
+        # E      File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1104, in resume_in_pad
+        # E       return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
+        # E     File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1154, in pad_image
+        # E       return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
+        # E     File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill
+        # E       num_channels, height, width = shape[-3:]
+        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
+        check_functional(F.pad, make_input(), padding=[1], check_torch_compile_smoke=check_torch_compile_smoke)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -4327,7 +4415,13 @@ def test_kernel(self, format, dtype, device):
 
     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_functional(self, format):
-        check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))
+        # TODO: Disabled torch.compile check due to the error:
+        # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format
+        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
+        # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format
+        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
+        # E   ValueError: For pure tensor inputs, `old_format` has to be passed.
+        check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format), check_torch_compile_smoke=False)
 
     def test_errors(self):
         input_tv_tensor = make_bounding_boxes()

From d18565643598dfeb419e96d009052e17e21b790e Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Wed, 8 Nov 2023 04:21:33 -0600
Subject: [PATCH 7/8] Fixed failing compiled execs

---
 test/test_transforms_v2.py         | 115 +++--------------------------
 torchvision/tv_tensors/__init__.py |   4 +
 2 files changed, 14 insertions(+), 105 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 4e3c6bf49dc..7f19b923559 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -193,11 +193,10 @@ def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs):
         return
 
     functional_compiled = torch.compile(functional)
-    functional_compiled(input.as_subclass(torch.Tensor), *args, **kwargs)
+    functional_compiled(input, *args, **kwargs)
 
-    explanation = torch._dynamo.explain(functional_compiled)(input.as_subclass(torch.Tensor), *args, **kwargs)
-    # TODO: Set expected values to 1, 0 once fixed the graph break related to function registration
-    assert explanation.graph_count in (1, 2)
+    explanation = torch._dynamo.explain(functional_compiled)(input, *args, **kwargs)
+    # TODO: Set expected value to 0 once fixed the graph break related to function registration
     assert explanation.graph_break_count in (0, 1)
 
 
@@ -1162,22 +1161,7 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        # TODO: Remove this when fixed
-        # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py:1573: in UNPACK_SEQUENCE
-        #     assert len(val) == inst.argval
-        # E   AssertionError:
-        # E
-        # E   from user code:
-        # E      File "vision/torchvision/transforms/v2/functional/_geometry.py", line 392, in resume_in_affine
-        # E       return kernel(
-        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 692, in affine_image
-        # E       return _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform
-        # E       num_channels, input_height, input_width = input_shape[-3:]
-        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
-        check_functional(
-            F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke
-        )
+        check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1601,12 +1585,7 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        # TODO: Remove this when fixed
-        # Error is the same as for TestAffine
-        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
-        check_functional(
-            F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS, check_torch_compile_smoke=check_torch_compile_smoke
-        )
+        check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -2671,38 +2650,8 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        # TODO: Remove this when fixed
-        # - TestElastic.test_functional[make_bounding_boxes]:
-        # /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:410: in _fn
-        #     return fn(*args, **kwargs)
-        # torchvision/transforms/v2/functional/_geometry.py:1705: in elastic
-        #     kernel = _get_kernel(elastic, type(inpt))
-        # torchvision/transforms/v2/functional/_geometry.py:1706: in resume_in_elastic
-        #     return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
-        # torchvision/transforms/v2/functional/_geometry.py:1741: in elastic_image
-        #     raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
-        # torchvision/transforms/v2/functional/_geometry.py:1741: in resume_in_elastic_image
-        #     raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
-        # E   ValueError: Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2])
-        #
-        # - TestElastic.test_functional[make_segmentation_mask]:
-        # E   AssertionError:
-        # E
-        # E   from user code:
-        # E      File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1706, in resume_in_elastic
-        # E       return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
-        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1746, in elastic_image
-        # E       output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 556, in _apply_grid_transform
-        # E       num_channels, input_height, input_width = input_shape[-3:]
-        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
         input = make_input()
-        check_functional(
-            F.elastic,
-            input,
-            displacement=self._make_displacement(input),
-            check_torch_compile_smoke=check_torch_compile_smoke,
-        )
+        check_functional(F.elastic, input, displacement=self._make_displacement(input))
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -2815,23 +2764,7 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        # TODO: Remove this when fixed
-        # E   AssertionError:
-        # E
-        # E   from user code:
-        # E      File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1324, in resume_in_crop
-        # E       return kernel(inpt, top=top, left=left, height=height, width=width)
-        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1343, in crop_image
-        # E       return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
-        # E     File "vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill
-        # E       num_channels, height, width = shape[-3:]
-        check_torch_compile_smoke = False if make_input == make_bounding_boxes else True
-        check_functional(
-            F.crop,
-            make_input(self.INPUT_SIZE),
-            **self.MINIMAL_CROP_KWARGS,
-            check_torch_compile_smoke=check_torch_compile_smoke,
-        )
+        check_functional(F.crop, make_input(self.INPUT_SIZE), **self.MINIMAL_CROP_KWARGS)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -3468,18 +3401,7 @@ def test_kernel_inplace(self, old_format, new_format):
 
     @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
     def test_functional(self, old_format, new_format):
-        # TODO: Disabled torch.compile check due to the error:
-        # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format
-        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
-        # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format
-        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
-        # E   ValueError: For pure tensor inputs, `old_format` has to be passed.
-        check_functional(
-            F.convert_bounding_box_format,
-            make_bounding_boxes(format=old_format),
-            new_format=new_format,
-            check_torch_compile_smoke=False,
-        )
+        check_functional(F.convert_bounding_box_format, make_bounding_boxes(format=old_format), new_format=new_format)
 
     @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
     @pytest.mark.parametrize("format_type", ["enum", "str"])
@@ -3753,18 +3675,7 @@ def test_kernel_video(self):
         [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
     )
     def test_functional(self, make_input):
-        # TODO: Remove this when fixed
-        # E   AssertionError:
-        # E
-        # E   from user code:
-        # E      File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1104, in resume_in_pad
-        # E       return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
-        # E     File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1154, in pad_image
-        # E       return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
-        # E     File "/vision/torchvision/transforms/v2/functional/_geometry.py", line 1168, in _pad_with_scalar_fill
-        # E       num_channels, height, width = shape[-3:]
-        check_torch_compile_smoke = False if make_input in (make_bounding_boxes, make_segmentation_mask) else True
-        check_functional(F.pad, make_input(), padding=[1], check_torch_compile_smoke=check_torch_compile_smoke)
+        check_functional(F.pad, make_input(), padding=[1])
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -4415,13 +4326,7 @@ def test_kernel(self, format, dtype, device):
 
     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_functional(self, format):
-        # TODO: Disabled torch.compile check due to the error:
-        # torchvision/transforms/v2/functional/_meta.py:219: in convert_bounding_box_format
-        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
-        # torchvision/transforms/v2/functional/_meta.py:219: in resume_in_convert_bounding_box_format
-        #     raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
-        # E   ValueError: For pure tensor inputs, `old_format` has to be passed.
-        check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format), check_torch_compile_smoke=False)
+        check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))
 
     def test_errors(self):
         input_tv_tensor = make_bounding_boxes()
diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index d55e10e8620..2e58d9d4c6a 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -8,6 +8,10 @@
 from ._video import Video
 
 
+# TODO: Fix this. We skip this method as it leads to
+# RecursionError: maximum recursion depth exceeded while calling a Python object
+# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph
+@torch.compiler.disable
 def wrap(wrappee, *, like, **kwargs):
     """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 

From 3c959d904bb41e8bad059d3c11b1260b7b576a48 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5@gmail.com>
Date: Fri, 10 Nov 2023 04:41:50 -0600
Subject: [PATCH 8/8] Revert "Removed temporary fixes"

This reverts commit 7f91053e219946189acb38c987e36c1bdfcb607c.
---
 test/test_transforms_v2.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 7f19b923559..9cba49440d1 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -220,7 +220,13 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, check_
 
     # Skip check on Windows as torch.compile does not work on Win32
     if check_torch_compile_smoke and sys.platform != "win32":
-        _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
+        # Temporary fix to catch deprectation warning
+        # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged:
+        import warnings
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=DeprecationWarning)
+            _check_functional_torch_compile_smoke(functional, input, *args, **kwargs)
 
 
 def check_functional_kernel_signature_match(functional, *, kernel, input_type):
@@ -4130,7 +4136,7 @@ def test_kernel_video(self):
 
     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
     def test_functional(self, make_input):
-        check_functional(F.equalize, make_input())
+        check_functional(F.equalize, make_input(), check_torch_compile_smoke=False)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),