@@ -1,17 +1,37 @@
+import functools
 from typing import List, Union
 
 import torch
 import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops, _has_ops
 
 from ..utils import _log_api_usage_once
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
 
 
+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call.
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
 # NB: all inputs are tensors
 def _bilinear_interpolate(
     input,  # [N, C, H, W]
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
         return tensor
 
 
-# This is a slow but pure Python and differentiable implementation of
-# roi_align. It potentially is a good basis for Inductor compilation
-# (but I have not benchmarked it) but today it is solely used for the
-# fact that its backwards can be implemented deterministically,
-# which is needed for the PT2 benchmark suite.
-#
+# This is a pure Python and differentiable implementation of roi_align. When
+# run in eager mode, it uses a lot of memory, but when compiled it has
+# acceptable memory usage. The main point of this implementation is that
+# its backwards is deterministic.
 # It is transcribed directly off of the roi_align CUDA kernel, see
 # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
-@torch._dynamo.allow_in_graph
+@lazy_compile(dynamic=True)
 def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
     orig_dtype = input.dtype
 
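
Where the old `torch._dynamo.allow_in_graph` marker only told dynamo to treat `_roi_align` as a graph-safe call, the new decorator actually compiles the function, and `dynamic=True` matters because the number of RoIs and the feature-map sizes vary between calls: it asks the compiler to trace with symbolic shapes so one graph serves many input sizes. A small illustration of that flag, using a toy function rather than anything from the patch:

```python
import torch


def f(x):
    return x.sin() + x.cos()


# dynamic=True traces with symbolic sizes, so inputs of different
# lengths reuse one compiled graph instead of recompiling per shape.
g = torch.compile(f, dynamic=True)
g(torch.randn(8))
g(torch.randn(1024))  # new size, same compiled graph
```
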
@@ -232,7 +250,9 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+        if (
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+        ) and is_compile_supported(input.device.type):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
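
With the extra `is_compile_supported(input.device.type)` guard, the pure-Python implementation is only chosen when dynamo can actually compile for the input's device; otherwise the call falls through to the C++ op (or to the `_assert_has_ops` error when the extension is missing). A sketch of exercising the deterministic branch, assuming a CUDA build with compile support (shapes are illustrative):

```python
import torch
from torchvision.ops import roi_align

input = torch.rand(1, 3, 32, 32, device="cuda", requires_grad=True)
# One RoI per row, in (batch_index, x1, y1, x2, y2) format.
rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]], device="cuda")

# Deterministic mode on CUDA now dispatches to the compiled _roi_align,
# whose backward is deterministic, instead of the CUDA kernel whose
# backward relies on nondeterministic atomic adds.
torch.use_deterministic_algorithms(True)
out = roi_align(input, rois, output_size=(7, 7), spatial_scale=1.0)
out.sum().backward()
```
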