Skip to content

Commit b2a965d

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] Force use of torch.compile on deterministic roi_align implementation (#8436)
Summary: Signed-off-by: Edward Z. Yang <ezyang@meta.com> Reviewed By: vmoens Differential Revision: D58283855 fbshipit-source-id: 914a91877c193b38f29af450a5935dd1ab5b20d7 Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent f3f318c commit b2a965d

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

test/test_ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
1515
from PIL import Image
1616
from torch import nn, Tensor
17+
from torch._dynamo.utils import is_compile_supported
1718
from torch.autograd import gradcheck
1819
from torch.nn.modules.utils import _pair
1920
from torchvision import models, ops
@@ -529,6 +530,10 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
529530
def test_backward(self, seed, device, contiguous, deterministic):
    """Run the parent backward test, skipping unsupported deterministic cases.

    A deterministic run is skipped when the device is "cpu" or "mps", or
    when ``is_compile_supported(device)`` is false. Non-deterministic runs
    are always forwarded to the parent implementation.
    """
    if deterministic:
        # pytest.skip raises, so at most one of these fires, checked in order.
        if device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        if device == "mps":
            pytest.skip("no deterministic implementation for mps")
        if not is_compile_supported(device):
            pytest.skip("deterministic implementation only if torch.compile supported")

    super().test_backward(seed, device, contiguous, deterministic)
533538

534539
def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):

torchvision/ops/roi_align.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import functools
12
from typing import List, Union
23

34
import torch
45
import torch._dynamo
56
import torch.fx
67
from torch import nn, Tensor
8+
from torch._dynamo.utils import is_compile_supported
79
from torch.jit.annotations import BroadcastingList2
810
from torch.nn.modules.utils import _pair
911
from torchvision.extension import _assert_has_ops, _has_ops
@@ -12,6 +14,24 @@
1214
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
1315

1416

17+
def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call

    This avoids eagerly importing dynamo.

    Args:
        **compile_kwargs: keyword arguments forwarded unchanged to
            ``torch.compile`` when the decorated function is first called.

    Returns:
        A decorator. The decorated function is a thin hook that compiles the
        original function on its first invocation, caches the compiled
        callable, and dispatches every call to it.
    """

    def decorate_fn(fn):
        # Cached compiled callable; stays None until the first call.
        compiled_fn = None

        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            nonlocal compiled_fn
            if compiled_fn is None:
                # Compile once. Callers that captured a reference to this hook
                # before the first call (e.g. via ``from module import name``)
                # reuse the cached callable instead of invoking torch.compile
                # again on every call.
                compiled_fn = functools.wraps(fn)(torch.compile(fn, **compile_kwargs))
                # Swap the module-level binding so later lookups by name skip
                # the hook entirely. Only do so when the global still *is* this
                # hook, so that an unrelated global of the same name (nested
                # function, extra decorator layer, function defined in another
                # module) is never overwritten.
                if globals().get(fn.__name__) is compile_hook:
                    globals()[fn.__name__] = compiled_fn
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn
33+
34+
1535
# NB: all inputs are tensors
1636
def _bilinear_interpolate(
1737
input, # [N, C, H, W]
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
86106
return tensor
87107

88108

89-
# This is a slow but pure Python and differentiable implementation of
90-
# roi_align. It potentially is a good basis for Inductor compilation
91-
# (but I have not benchmarked it) but today it is solely used for the
92-
# fact that its backwards can be implemented deterministically,
93-
# which is needed for the PT2 benchmark suite.
94-
#
109+
# This is a pure Python and differentiable implementation of roi_align. When
110+
# run in eager mode, it uses a lot of memory, but when compiled it has
111+
# acceptable memory usage. The main point of this implementation is that
112+
# its backwards is deterministic.
95113
# It is transcribed directly off of the roi_align CUDA kernel, see
96114
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
97-
@torch._dynamo.allow_in_graph
115+
@lazy_compile(dynamic=True)
98116
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
99117
orig_dtype = input.dtype
100118

@@ -232,7 +250,9 @@ def roi_align(
232250
if not isinstance(rois, torch.Tensor):
233251
rois = convert_boxes_to_roi_format(rois)
234252
if not torch.jit.is_scripting():
235-
if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
253+
if (
254+
not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
255+
) and is_compile_supported(input.device.type):
236256
return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
237257
_assert_has_ops()
238258
return torch.ops.torchvision.roi_align(

0 commit comments

Comments
 (0)