import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch

from torchvision import datapoints

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]

def is_simple_tensor(inpt: Any) -> bool:
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)

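# Illustration only (not part of the original module): what counts as a "simple" tensor,
# i.e. a plain torch.Tensor that is not a Datapoint subclass.
#
#   is_simple_tensor(torch.rand(3, 32, 32))                    # True: plain tensor
#   is_simple_tensor(datapoints.Image(torch.rand(3, 32, 32)))  # False: Datapoint subclass
#   is_simple_tensor([1, 2, 3])                                # False: not a tensor at all
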
# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}

def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap(), because the Datapoint type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return datapoints.wrap(output, like=inpt)

    return wrapper

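# Illustration only (not part of the original module): what the wrapper above buys us for a
# hypothetical kernel. `my_rotate_image` and the angle value are made-up placeholders.
#
#   image = datapoints.Image(torch.rand(3, 32, 32))
#   wrapped = _kernel_datapoint_wrapper(my_rotate_image)
#   out = wrapped(image, angle=30.0)
#
#   # `my_rotate_image` only ever sees a plain torch.Tensor (no __torch_function__ overhead),
#   # and `out` comes back re-wrapped as a datapoints.Image via datapoints.wrap().
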
def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_datapoint_wrapper(kernel)
            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
            else kernel
        )
        return kernel

    return decorator

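# Illustration only (not part of the original module): how the internal decorator is typically
# used. `resize` and `resize_image` below are placeholders standing in for a functional and a
# type-specific kernel.
#
#   @_register_kernel_internal(resize, datapoints.Image)
#   def resize_image(image, size):
#       ...
#
#   # _KERNEL_REGISTRY[resize][datapoints.Image] now holds resize_image, wrapped by
#   # _kernel_datapoint_wrapper because datapoint_wrapper defaults to True. The undecorated
#   # kernel itself is returned unchanged, so it can still be called directly.
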
def _name_to_functional(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None

_BUILTIN_DATAPOINT_TYPES = {
    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
}

def register_kernel(functional, datapoint_cls):
    """Decorate a kernel to register it for a functional and a (custom) datapoint type.

    See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
    details.
    """
    if isinstance(functional, str):
        functional = _name_to_functional(name=functional)
    elif not (
        callable(functional)
        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
            f"but got {functional}."
        )

    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
            f"but got {datapoint_cls}."
        )

    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}.")

    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)

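# Illustration only (not part of the original module): registering a kernel for a user-defined
# datapoint. `MyDatapoint` and `hflip_my_datapoint` are hypothetical names; F.horizontal_flip is
# just one example of a functional from torchvision.transforms.v2.functional.
#
#   from torchvision.transforms.v2 import functional as F
#
#   class MyDatapoint(datapoints.Datapoint):
#       pass
#
#   @register_kernel(F.horizontal_flip, MyDatapoint)   # passing the string "horizontal_flip" also works
#   def hflip_my_datapoint(my_dp, *args, **kwargs):
#       ...  # datapoint_wrapper=False here: the kernel receives the datapoint itself and its
#            # output is not re-wrapped automatically
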
def _get_kernel(functional, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
        elif cls is datapoints.Datapoint:
            # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop
            # the MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we
            # don't allow kernels to be registered for datapoints.Datapoint anyway.
            break

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )

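# Illustration only (not part of the original module): dispatch walks the MRO of the input type,
# so a subclass without its own kernel falls back to its parent's kernel. `MyImage` is a
# hypothetical subclass used purely for the sketch.
#
#   class MyImage(datapoints.Image):
#       pass
#
#   # Assuming some kernel is registered for datapoints.Image under `functional`,
#   # _get_kernel(functional, MyImage) returns that kernel, because datapoints.Image appears in
#   # MyImage.__mro__ before datapoints.Datapoint (where the traversal would stop).
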
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop.
# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper instead of the
# datapoint_wrapper: bool flag.
def _register_five_ten_crop_kernel_internal(functional, input_type):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(datapoints.wrap(o, like=inpt) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
        return kernel

    return decorator
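
# Illustration only (not part of the original module): unlike _kernel_datapoint_wrapper, the
# wrapper above re-wraps every element of the returned container individually, which is what
# five_crop / ten_crop need since their kernels return a tuple of crops rather than a single
# tensor. For example, a kernel that returns a 5-tuple of tensors for a datapoints.Image input
# yields a 5-tuple of datapoints.Image outputs.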