
Commit a9a726a

Make v2 transforms authoring public (#8787)

1 parent 48f01de · commit a9a726a
18 files changed · +260 −148 lines changed

docs/source/transforms.rst (+9)

@@ -508,11 +508,20 @@ are combining pairs of images together. These can be used after the dataloader
 Developer tools
 ^^^^^^^^^^^^^^^

+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    v2.Transform
+
 .. autosummary::
     :toctree: generated/
     :template: function.rst

     v2.functional.register_kernel
+    v2.query_size
+    v2.query_chw
+    v2.get_bounding_boxes


 V1 API Reference
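Taken together, the entries documented above form the public authoring surface for v2 transforms. A minimal sketch of how they fit together (the class name and the no-op logic are illustrative only, not part of this commit):

from typing import Any, Dict, List

from torchvision.transforms import v2


class MyTransform(v2.Transform):
    # v2.Transform is the newly documented base class for custom transforms
    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # v2.query_size (and v2.query_chw) inspect the flat list of inputs
        height, width = v2.query_size(flat_inputs)
        return dict(height=height, width=width)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        return inpt  # no-op placeholder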

gallery/transforms/plot_custom_transforms.py (+98 −19)

@@ -12,6 +12,8 @@
 """

 # %%
+from typing import Any, Dict, List
+
 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import v2
@@ -89,33 +91,110 @@ def forward(self, img, bboxes, label):  # we assume inputs are always structured
 # A key feature of the builtin Torchvision V2 transforms is that they can accept
 # arbitrary input structure and return the same structure as output (with
 # transformed entries). For example, transforms can accept a single image, or a
-# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
+# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
+# an example on the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:

 structured_input = {
     "img": img,
     "annotations": (bboxes, label),
-    "something_that_will_be_ignored": (1, "hello")
+    "something that will be ignored": (1, "hello"),
+    "another tensor that is ignored": torch.arange(10),
 }
 structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

 assert isinstance(structured_output, dict)
-assert structured_output["something_that_will_be_ignored"] == (1, "hello")
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
+print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
+
+# %%
+# Basics: override the `transform()` method
+# -----------------------------------------
+#
+# In order to support arbitrary inputs in your custom transform, you will need
+# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
+# `.transform()` method (not the `forward()` method!). Below is a basic example:
+
+
+class MyCustomTransform(v2.Transform):
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        if type(inpt) == torch.Tensor:
+            print(f"I'm transforming an image of shape {inpt.shape}")
+            return inpt + 1  # dummy transformation
+        elif isinstance(inpt, tv_tensors.BoundingBoxes):
+            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
+            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation
+
+
+my_custom_transform = MyCustomTransform()
+structured_output = my_custom_transform(structured_input)
+
+assert isinstance(structured_output, dict)
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
 print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

 # %%
-# If you want to reproduce this behavior in your own transform, we invite you to
-# look at our `code
-# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
-# and adapt it to your needs.
-#
-# In brief, the core logic is to unpack the input into a flat list using `pytree
-# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
-# then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all TVTensors are
-# tensor-subclasses) plus some custom logic that is out of score here - check the
-# code for details. The (potentially transformed) entries are then repacked and
-# returned, in the same structure as the input.
-#
-# We do not provide public dev-facing tools to achieve that at this time, but if
-# this is something that would be valuable to you, please let us know by opening
-# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
+# An important thing to note is that when we call ``my_custom_transform`` on
+# ``structured_input``, the input is flattened and then each individual part is
+# passed to ``transform()``. That is, ``transform()`` receives the input image,
+# then the bounding boxes, etc. Within ``transform()``, you can decide how to
+# transform each input, based on their type.
+#
+# If you're curious why the other tensor (``torch.arange()``) didn't get passed
+# to ``transform()``, see :ref:`this note <passthrough_heuristic>` for more
+# details.
+#
+# Advanced: The ``make_params()`` method
+# --------------------------------------
+#
+# The ``make_params()`` method is called internally before calling
+# ``transform()`` on each input. This is typically useful to generate random
+# parameter values. In the example below, we use it to randomly apply the
+# transformation with a probability of 0.5.
+
+
+class MyRandomTransform(MyCustomTransform):
+    def __init__(self, p=0.5):
+        self.p = p
+        super().__init__()
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        apply_transform = (torch.rand(size=(1,)) < self.p).item()
+        params = dict(apply_transform=apply_transform)
+        return params
+
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        if not params["apply_transform"]:
+            print("Not transforming anything!")
+            return inpt
+        else:
+            return super().transform(inpt, params)
+
+
+my_random_transform = MyRandomTransform()
+
+torch.manual_seed(0)
+_ = my_random_transform(structured_input)  # transforms
+_ = my_random_transform(structured_input)  # doesn't transform
+
+# %%
+#
+# .. note::
+#
+#     It's important for such random parameter generation to happen within
+#     ``make_params()`` and not within ``transform()``, so that for a given
+#     transform call, the same RNG applies to all the inputs in the same way. If
+#     we were to perform the RNG within ``transform()``, we would risk e.g.
+#     transforming the image while *not* transforming the bounding boxes.
+#
+# The ``make_params()`` method takes the list of all the inputs as a parameter
+# (each of the elements in this list will later be passed to ``transform()``).
+# You can use ``flat_inputs`` to e.g. figure out the dimensions of the input,
+# using :func:`~torchvision.transforms.v2.query_chw` or
+# :func:`~torchvision.transforms.v2.query_size`.
+#
+# ``make_params()`` should return a dict (or actually, anything you want) that
+# will then be passed to ``transform()``.
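As a complement to the tutorial above, a hedged sketch of how the newly public v2.get_bounding_boxes helper can be used inside make_params(); the CropToBoxes name and the crop-to-union logic are illustrative assumptions, not part of this commit:

from typing import Any, Dict, List

from torchvision import tv_tensors
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F


class CropToBoxes(v2.Transform):
    """Illustrative transform: crop every input to the union of the bounding boxes."""

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        bboxes = v2.get_bounding_boxes(flat_inputs)  # the single BoundingBoxes entry
        xyxy = F.convert_bounding_box_format(bboxes, new_format=tv_tensors.BoundingBoxFormat.XYXY)
        x1, y1 = int(xyxy[:, 0].min()), int(xyxy[:, 1].min())
        x2, y2 = int(xyxy[:, 2].max()), int(xyxy[:, 3].max())
        return dict(top=y1, left=x1, height=y2 - y1, width=x2 - x1)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        # F.crop dispatches on the input type (image, mask, boxes, ...)
        return F.crop(inpt, **params)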

references/segmentation/v2_extras.py (+2 −2)

@@ -10,13 +10,13 @@ def __init__(self, size, fill=0):
         self.size = size
         self.fill = v2._utils._setup_fill_arg(fill)

-    def _get_params(self, sample):
+    def make_params(self, sample):
         _, height, width = v2._utils.query_chw(sample)
         padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)

-    def _transform(self, inpt, params):
+    def transform(self, inpt, params):
         if not params["needs_padding"]:
             return inpt


test/test_prototype_transforms.py (+4 −4)

@@ -159,7 +159,7 @@ def test__copy_paste(self, label_type):


 class TestFixedSizeCrop:
-    def test__get_params(self, mocker):
+    def test_make_params(self, mocker):
         crop_size = (7, 7)
         batch_shape = (10,)
         canvas_size = (11, 5)
@@ -170,7 +170,7 @@ def test__get_params(self, mocker):
             make_image(size=canvas_size, color_space="RGB"),
             make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
         ]
-        params = transform._get_params(flat_inputs)
+        params = transform.make_params(flat_inputs)

         assert params["needs_crop"]
         assert params["height"] <= crop_size[0]
@@ -191,7 +191,7 @@ def test__transform_culling(self, mocker):

         is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
         mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
             return_value=dict(
                 needs_crop=True,
                 top=0,
@@ -229,7 +229,7 @@ def test__transform_bounding_boxes_clamping(self, mocker):
         canvas_size = (10, 10)

         mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
             return_value=dict(
                 needs_crop=True,
                 top=0,
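Since make_params() and transform() are now public, a custom transform's hooks can also be exercised directly, without mocker.patch. A small sketch, assuming the MyRandomTransform class from the gallery example above is in scope:

import torch

t = MyRandomTransform(p=1.0)  # with p=1.0 the transform always applies

img = torch.zeros(3, 4, 4)
params = t.make_params([img])  # call the public hook directly
assert params["apply_transform"]
assert (t.transform(img, params) == img + 1).all()  # the dummy "+ 1" transformation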
