1
+ import functools
2
+
1
3
import torch
2
4
import torch .library
3
5
6
8
7
9
from torch ._prims_common import check
8
10
9
- _meta_lib = torch .library .Library ("torchvision" , "IMPL" , "Meta" )
10
11
11
- vision = torch .ops .torchvision
12
+ @functools .lru_cache (None )
13
+ def get_meta_lib ():
14
+ return torch .library .Library ("torchvision" , "IMPL" , "Meta" )
12
15
13
16
14
- def register_meta (op ):
17
+ def register_meta (op_name , overload_name = "default" ):
15
18
def wrapper (fn ):
16
- _meta_lib .impl (op , fn )
19
+ if torchvision .extension ._has_ops ():
20
+ get_meta_lib ().impl (getattr (getattr (torch .ops .torchvision , op_name ), overload_name ), fn )
17
21
return fn
18
22
19
23
return wrapper
20
24
21
25
22
- @register_meta (vision . roi_align . default )
26
+ @register_meta (" roi_align" )
23
27
def meta_roi_align (input , rois , spatial_scale , pooled_height , pooled_width , sampling_ratio , aligned ):
24
28
check (rois .size (1 ) == 5 , lambda : "rois must have shape as Tensor[K, 5]" )
25
29
check (
@@ -34,7 +38,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
34
38
return input .new_empty ((num_rois , channels , pooled_height , pooled_width ))
35
39
36
40
37
- @register_meta (vision . _roi_align_backward . default )
41
+ @register_meta (" _roi_align_backward" )
38
42
def meta_roi_align_backward (
39
43
grad , rois , spatial_scale , pooled_height , pooled_width , batch_size , channels , height , width , sampling_ratio , aligned
40
44
):
0 commit comments