
Force use of torch.compile on deterministic roi_align implementation #8436


Merged: 8 commits, merged on May 29, 2024
Changes from 6 commits
5 changes: 5 additions & 0 deletions test/test_ops.py
@@ -14,6 +14,7 @@
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
from PIL import Image
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import models, ops
@@ -529,6 +530,10 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
    def test_backward(self, seed, device, contiguous, deterministic):
        if deterministic and device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        if deterministic and device == "mps":
            pytest.skip("no deterministic implementation for mps")
        if deterministic and not is_compile_supported(device):
            pytest.skip("deterministic implementation only if torch.compile supported")
        super().test_backward(seed, device, contiguous, deterministic)

    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
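For context, a minimal, illustrative sketch (not part of the diff) of the helper the test now imports; it assumes is_compile_supported from torch._dynamo.utils takes a device-type string and reports whether torch.compile can target that device on the current build:

from torch._dynamo.utils import is_compile_supported

# Illustrative only: actual results depend on the PyTorch build and available backends.
print(is_compile_supported("cpu"))   # typically True when dynamo is available
print(is_compile_supported("cuda"))  # True when a working Triton/CUDA toolchain is present
# "mps" is covered by the explicit skip above, since there is no deterministic MPS path.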
36 changes: 28 additions & 8 deletions torchvision/ops/roi_align.py
@@ -1,9 +1,11 @@
import functools
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops
@@ -12,6 +14,24 @@
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


def lazy_compile(**compile_kwargs):
"""Lazily wrap a function with torch.compile on the first call

This avoids eagerly importing dynamo.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I understanding this correctly?

Suggested change
This avoids eagerly importing dynamo.
This avoids eagerly compiling a function at import time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. Even with torch.compile at top level it isn't compiled until you call it the first time. But importing dynamo has undesirable side effects for eager mode only users so it's best not to do it.

"""

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            # Rebind the module-level name to the compiled function so later
            # calls skip this hook and go straight to the compiled version.
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn
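A minimal usage sketch (illustrative, not part of the diff): nothing is compiled at import time; the first call triggers torch.compile, and the module-level name is rebound to the compiled function so later calls dispatch to it directly.

@lazy_compile(dynamic=True)
def _toy_fn(x):  # hypothetical example function, not in torchvision
    return x * 2 + 1

a = _toy_fn(torch.arange(4))  # first call: compiles _toy_fn, then runs the compiled version
b = _toy_fn(torch.arange(8))  # subsequent calls use the already-compiled function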


# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
        return tensor


# This is a slow but pure Python and differentiable implementation of
# roi_align. It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically,
# which is needed for the PT2 benchmark suite.
#
# This is a pure Python and differentiable implementation of roi_align. When
# run in eager mode, it uses a lot of memory, but when compiled it has
# acceptable memory usage. The main point of this implementation is that
# its backwards is deterministic.
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
orig_dtype = input.dtype

@@ -232,7 +250,9 @@ def roi_align(
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    if not torch.jit.is_scripting():
        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
        if (
            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
Review comment (Member):
Should we just remove the mps part here, since you mentioned MPS doesn't even work with torch.compile?

Suggested change:
-            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda)

Reply (Contributor, Author):
I opted to keep it around, because it was explicitly added by @qqaatw, but I don't really mind either way.

Reply (Contributor):
Sorry for the late reply! I'm OK with either way, whichever is best for development. From the mentioned issue it seems only relevant to CUDA; is MPS similarly memory-hungry with the deterministic algorithm?
        ) and is_compile_supported(input.device.type):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
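As a closing illustration (a hypothetical sketch, not part of the PR): with deterministic algorithms enabled on a CUDA tensor, roi_align now routes through the compiled pure-Python _roi_align above, so the backward pass is deterministic.

import torch
from torchvision.ops import roi_align

torch.use_deterministic_algorithms(True)

x = torch.rand(1, 3, 32, 32, device="cuda", requires_grad=True)
# Each row of rois is [batch_index, x1, y1, x2, y2].
rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="cuda")

out = roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
out.sum().backward()  # the gradient accumulation runs through the compiled Python path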