Force use of torch.compile on deterministic roi_align implementation #8436
torchvision/ops/roi_align.py

```diff
@@ -1,9 +1,11 @@
+import functools
 from typing import List, Union
 
 import torch
 import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops, _has_ops
```
```diff
@@ -12,6 +14,24 @@
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
 
 
+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
 # NB: all inputs are tensors
 def _bilinear_interpolate(
     input,  # [N, C, H, W]
```
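To see how the decorator behaves, here is a hedged usage sketch. It is not part of the PR: `scale_and_shift` is a made-up function, and the decorator is repeated verbatim from the diff above so the snippet runs standalone. Nothing is compiled, and dynamo is not imported, until the first call; the `globals()` assignment then rebinds the module-level name so later lookups dispatch straight to the compiled function.

```python
import functools

import torch


def lazy_compile(**compile_kwargs):
    """Same decorator as in the diff above, repeated so this sketch is self-contained."""

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            # First call only: compile, then rebind the module-global name so
            # subsequent lookups of fn.__name__ skip this hook entirely.
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn


# Hypothetical function, not from the PR.
@lazy_compile(dynamic=True)
def scale_and_shift(x, scale, shift):
    return x * scale + shift


x = torch.randn(16)
scale_and_shift(x, 2.0, 1.0)  # first call: torch.compile runs here
scale_and_shift(x, 2.0, 1.0)  # later calls: go straight to the compiled fn
```

Note that the rebinding happens in the decorator's defining module, which works here because `_roi_align` lives in the same file; the fast path only applies to call sites that resolve the function through its module namespace.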
```diff
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
     return tensor
 
 
-# This is a slow but pure Python and differentiable implementation of
-# roi_align. It potentially is a good basis for Inductor compilation
-# (but I have not benchmarked it) but today it is solely used for the
-# fact that its backwards can be implemented deterministically,
-# which is needed for the PT2 benchmark suite.
-#
+# This is a pure Python and differentiable implementation of roi_align. When
+# run in eager mode, it uses a lot of memory, but when compiled it has
+# acceptable memory usage. The main point of this implementation is that
+# its backwards is deterministic.
 # It is transcribed directly off of the roi_align CUDA kernel, see
 # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
-@torch._dynamo.allow_in_graph
+@lazy_compile(dynamic=True)
 def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
     orig_dtype = input.dtype
```
```diff
@@ -232,7 +250,9 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+        if (
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+        ) and is_compile_supported(input.device.type):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
```

Review comment: Should we just remove the MPS part here, since you mentioned MPS doesn't even work with torch.compile?

Reply: I opted to keep it around, because it was explicitly added by @qqaatw, but I don't really mind either way.

Reply from @qqaatw: Sorry for the late reply! I'm OK with whichever way is best for development. From the mentioned issue it seems only relevant to CUDA; is MPS similarly memory hungry with the deterministic algorithm?
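To make the new dispatch concrete, here is a hedged end-to-end sketch. It is my illustration, not from the PR, and it assumes a CUDA machine with a torchvision build that includes this change: with deterministic algorithms enabled on a CUDA (or MPS) tensor, `roi_align` routes to the compiled pure-Python `_roi_align` instead of the nondeterministic CUDA kernel, but only when `is_compile_supported` reports that torch.compile works for the device type.

```python
import torch
from torchvision.ops import roi_align

torch.use_deterministic_algorithms(True)

# Assumes a CUDA device is available; not from the PR itself.
x = torch.randn(1, 3, 32, 32, device="cuda", requires_grad=True)
# rois are rows of (batch_index, x1, y1, x2, y2) in input coordinates
rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="cuda")

# Deterministic algorithms + CUDA input + compile support: this call takes
# the compiled pure-Python _roi_align path, so backward is deterministic.
out = roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
out.sum().backward()  # gradient accumulation order is now reproducible
```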
Review comment: Am I understanding this correctly?

Reply: Nope. Even with torch.compile at the top level, it isn't compiled until you call it the first time. But importing dynamo has undesirable side effects for eager-mode-only users, so it's best not to do it.
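The laziness described in this reply is easy to observe directly. A minimal sketch (my snippet, not from the thread): wrapping a function with `torch.compile` returns almost instantly, the expensive trace-and-compile happens on the first call, and later calls reuse the compiled artifact.

```python
import time

import torch


def f(x):
    return torch.sin(x) + torch.cos(x)


t0 = time.perf_counter()
g = torch.compile(f)  # cheap: just wraps f, no compilation yet
t1 = time.perf_counter()

g(torch.randn(8))  # expensive: dynamo traces and compiles here
t2 = time.perf_counter()

g(torch.randn(8))  # fast: reuses the compiled artifact
t3 = time.perf_counter()

print(f"wrap {t1 - t0:.4f}s | first call {t2 - t1:.4f}s | second call {t3 - t2:.4f}s")
```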