# _container.py (forked from pytorch/vision)
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch

from torch import nn
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import Transform


class Compose(Transform):
    """[BETA] Composes several transforms together.

    .. v2betastatus:: Compose transform

    This transform does not support torchscript. Please see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and
        do not require ``lambda`` functions or ``PIL.Image``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        elif not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # With a single positional argument, pass the bare object through the chain;
        # with several (e.g. image and target), unpack the previous transform's
        # output tuple into the next transform's arguments.
        needs_unpacking = len(inputs) > 1
        for transform in self.transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)
        return outputs

    def extra_repr(self) -> str:
        format_string = []
        for t in self.transforms:
            format_string.append(f"    {t}")
        return "\n".join(format_string)
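

# A minimal usage sketch for ``Compose``: the transforms named below exist in
# ``torchvision.transforms.v2``; the input shape is an arbitrary choice.
#
#   >>> import torch
#   >>> from torchvision.transforms import v2
#   >>> pipeline = v2.Compose([
#   >>>     v2.CenterCrop(10),
#   >>>     v2.ConvertImageDtype(torch.float32),
#   >>> ])
#   >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#   >>> out = pipeline(img)  # each transform feeds the next; out has shape (3, 10, 10)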


class RandomApply(Transform):
    """[BETA] Apply a list of transformations randomly with a given probability.

    .. v2betastatus:: RandomApply transform

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of a
        list/tuple of transforms, as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and
        do not require ``lambda`` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return {"transforms": self.transforms, "p": self.p}

    def forward(self, *inputs: Any) -> Any:
        needs_unpacking = len(inputs) > 1

        # Draw once per call: with probability 1 - p, skip the whole list and
        # return the inputs unchanged.
        if torch.rand(1) >= self.p:
            return inputs if needs_unpacking else inputs[0]

        for transform in self.transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)
        return outputs

    def extra_repr(self) -> str:
        format_string = []
        for t in self.transforms:
            format_string.append(f"    {t}")
        return "\n".join(format_string)
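

# A minimal usage sketch for ``RandomApply``: ``ColorJitter`` exists in
# ``torchvision.transforms.v2``; the brightness value and input shape are
# arbitrary choices.
#
#   >>> import torch
#   >>> from torchvision.transforms import v2
#   >>> transform = v2.RandomApply([v2.ColorJitter(brightness=0.5)], p=0.3)
#   >>> img = torch.rand(3, 32, 32)
#   >>> out = transform(img)  # with probability 0.7, ``out`` is ``img`` unchanged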


class RandomChoice(Transform):
    """[BETA] Apply a single transformation randomly picked from a list.

    .. v2betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence of callables): list of transformations to choose from
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Normalize the weights so they form a proper probability distribution.
        total = sum(p)
        self.p = [prob / total for prob in p]

    def forward(self, *inputs: Any) -> Any:
        # Sample one index according to the normalized weights and apply only
        # that transform.
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        transform = self.transforms[idx]
        return transform(*inputs)
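

# A minimal sketch of the weighting behavior: the weights need not sum to 1, so
# ``p=[2, 1]`` normalizes to ``[2/3, 1/3]`` and the rotation is picked twice as
# often as the flip. Both transforms exist in ``torchvision.transforms.v2``.
#
#   >>> import torch
#   >>> from torchvision.transforms import v2
#   >>> transform = v2.RandomChoice(
#   >>>     [v2.RandomRotation(30), v2.RandomHorizontalFlip(p=1.0)],
#   >>>     p=[2, 1],
#   >>> )
#   >>> out = transform(torch.rand(3, 32, 32))  # exactly one of the two is applied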


class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. v2betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence of callables): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        needs_unpacking = len(inputs) > 1
        # Apply every transform exactly once, in a freshly shuffled order per call.
        for idx in torch.randperm(len(self.transforms)):
            transform = self.transforms[idx]
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)
        return outputs
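

# A minimal usage sketch for ``RandomOrder``: both transforms are always applied,
# only their order is shuffled on each call. Both exist in
# ``torchvision.transforms.v2``; the input shape is an arbitrary choice.
#
#   >>> import torch
#   >>> from torchvision.transforms import v2
#   >>> transform = v2.RandomOrder([v2.ColorJitter(), v2.RandomHorizontalFlip()])
#   >>> out = transform(torch.rand(3, 32, 32))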