"""

# %%
from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# %%
# A key feature of the built-in Torchvision V2 transforms is that they can
# accept arbitrary input structure and return the same structure as output
# (with transformed entries). For example, transforms can accept a single
# image, or a tuple of ``(img, label)``, or an arbitrary nested dictionary as
# input. Here's an example using the built-in transform
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something that will be ignored": (1, "hello"),
    "another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
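
# %%
# The same machinery also works on simpler structures. As a quick
# illustration (a sketch, not from the original tutorial), applying the same
# transform to an ``(img, label)`` tuple returns a tuple, with the image
# flipped and the label passed through unchanged:

tuple_output = v2.RandomHorizontalFlip(p=1)((img, label))
assert isinstance(tuple_output, tuple) and len(tuple_output) == 2
flipped_img, same_label = tuple_output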

# %%
# Basics: override the ``transform()`` method
# --------------------------------------------
#
# In order to support arbitrary inputs in your custom transform, you will need
# to inherit from :class:`~torchvision.transforms.v2.Transform` and override
# the ``transform()`` method (not the ``forward()`` method!). Below is a basic
# example:


class MyCustomTransform(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) == torch.Tensor:  # exact type check: TVTensors are Tensor subclasses
            print(f"I'm transforming an image of shape {inpt.shape}")
            return inpt + 1  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation


my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# An important thing to note is that when we call ``my_custom_transform`` on
# ``structured_input``, the input is flattened and then each individual part
# is passed to ``transform()``. That is, ``transform()`` receives the input
# image, then the bounding boxes, etc. Within ``transform()``, you can decide
# how to transform each input, based on its type.
#
# If you're curious why the other tensor (``torch.arange()``) didn't get
# passed to ``transform()``, see :ref:`this note <passthrough_heuristic>` for
# more details.
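#
# To make the flattening idea concrete, here is a rough sketch using
# PyTorch's private ``torch.utils._pytree`` helpers. This is for illustration
# only: the actual v2 machinery is more involved, but it follows the same
# flatten / transform / repack structure.

from torch.utils._pytree import tree_flatten, tree_unflatten  # private API, shown for illustration

leaves, spec = tree_flatten(structured_input)
print(f"The input flattens into {len(leaves)} leaves")
# A transform visits each leaf and transforms only the supported types
# (identity here, to keep the sketch short)...
new_leaves = list(leaves)
# ...and then repacks the (possibly transformed) leaves into the input's structure:
repacked = tree_unflatten(new_leaves, spec)
assert isinstance(repacked, dict) and repacked.keys() == structured_input.keys()

# %%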
# Advanced: The ``make_params()`` method
# --------------------------------------
#
# The ``make_params()`` method is called internally before calling
# ``transform()`` on each input. This is typically useful to generate random
# parameter values. In the example below, we use it to randomly apply the
# transformation with a probability of 0.5.


class MyRandomTransform(MyCustomTransform):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        apply_transform = (torch.rand(size=(1,)) < self.p).item()
        params = dict(apply_transform=apply_transform)
        return params

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if not params["apply_transform"]:
            print("Not transforming anything!")
            return inpt
        else:
            return super().transform(inpt, params)


my_random_transform = MyRandomTransform()

torch.manual_seed(0)
_ = my_random_transform(structured_input)  # transforms
_ = my_random_transform(structured_input)  # doesn't transform

# %%
#
# .. note::
#
#     It's important for such random parameter generation to happen within
#     ``make_params()`` and not within ``transform()``, so that for a given
#     transform call, the same RNG applies to all the inputs in the same way.
#     If we were to perform the RNG within ``transform()``, we would risk e.g.
#     transforming the image while *not* transforming the bounding boxes.
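#
# To make the pitfall concrete, here is a sketch (not part of the original
# example) of what *not* to do, with the random draw inside ``transform()``:

class BadRandomTransform(MyCustomTransform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        # BAD: a fresh coin flip for every input means the image and its
        # bounding boxes can end up transformed inconsistently.
        if (torch.rand(size=(1,)) < 0.5).item():
            return inpt  # skipped for this input only!
        return super().transform(inpt, params)

# %%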
# The ``make_params()`` method takes the list of all the inputs as a
# parameter (each of the elements in this list will later be passed to
# ``transform()``). You can use ``flat_inputs`` to e.g. figure out the
# dimensions of the input, using
# :func:`~torchvision.transforms.v2.query_chw` or
# :func:`~torchvision.transforms.v2.query_size`.
#
# ``make_params()`` should return a dict (or actually, anything you want)
# that will then be passed to ``transform()``.
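
# %%
# As an example of both points, here is a small sketch (hypothetical, not
# from the original tutorial) of a crop-style transform: ``make_params()``
# uses :func:`~torchvision.transforms.v2.query_size` on ``flat_inputs`` to
# draw a random offset once, and that same offset is then applied to every
# input via the returned params dict. The name ``MyRandomTopCrop`` and the
# assumption that inputs are taller than ``crop_height`` are ours, and a
# full implementation would also adjust the bounding boxes.


class MyRandomTopCrop(v2.Transform):
    def __init__(self, crop_height: int):
        super().__init__()
        self.crop_height = crop_height

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = v2.query_size(flat_inputs)  # (H, W) shared by all inputs
        top = int(torch.randint(0, height - self.crop_height + 1, size=(1,)))
        return dict(top=top)  # passed to every transform() call below

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) == torch.Tensor:
            # Crop the image rows; the same ``top`` is used for every input.
            return inpt[..., params["top"]:params["top"] + self.crop_height, :]
        return inpt  # bounding boxes etc. left untouched in this sketch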