random_affine.py
from numbers import Number
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import SimpleITK as sitk
import torch
from ....constants import INTENSITY
from ....constants import TYPE
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....typing import TypeRangeFloat
from ....typing import TypeSextetFloat
from ....typing import TypeTripletFloat
from ....utils import get_major_sitk_version
from ....utils import to_tuple
from ...spatial_transform import SpatialTransform
from .. import RandomTransform
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
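# A value of this type may be a single number, a pair defining a range, a
# triplet (one value per axis), or a sextet (one range per axis); see the
# parameter descriptions in :class:`RandomAffine` below.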
class RandomAffine(RandomTransform, SpatialTransform):
r"""Apply a random affine transformation and resample the image.
Args:
scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
scaling ranges.
The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
If two values :math:`(a, b)` are provided,
then :math:`s_i \sim \mathcal{U}(a, b)`.
If only one value :math:`x` is provided,
then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
If three values :math:`(x_1, x_2, x_3)` are provided,
then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
making the objects inside look twice as small while preserving
the physical size and position of the image bounds.
degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
rotation ranges in degrees.
Rotation angles around each axis are
:math:`(\theta_1, \theta_2, \theta_3)`,
where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
If two values :math:`(a, b)` are provided,
then :math:`\theta_i \sim \mathcal{U}(a, b)`.
If only one value :math:`x` is provided,
then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
If three values :math:`(x_1, x_2, x_3)` are provided,
then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
translation ranges in mm.
Translation along each axis is :math:`(t_1, t_2, t_3)`,
where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
If two values :math:`(a, b)` are provided,
then :math:`t_i \sim \mathcal{U}(a, b)`.
If only one value :math:`x` is provided,
then :math:`t_i \sim \mathcal{U}(-x, x)`.
If three values :math:`(x_1, x_2, x_3)` are provided,
then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
For example, if the image is in RAS+ orientation (e.g., after
applying :class:`~torchio.transforms.preprocessing.ToCanonical`)
and the translation is :math:`(10, 20, 30)`, the sample will move
10 mm to the right, 20 mm to the front, and 30 mm upwards.
If the image was in, e.g., PIR+ orientation, the sample will move
10 mm to the back, 20 mm downwards, and 30 mm to the right.
isotropic: If ``True``, only one scaling factor will be sampled for all dimensions,
i.e. :math:`s_1 = s_2 = s_3`.
If one value :math:`x` is provided in :attr:`scales`, the scaling factor along all
dimensions will be :math:`s \sim \mathcal{U}(1 - x, 1 + x)`.
If two values :math:`(a, b)` are provided in :attr:`scales`, the scaling factor along all
dimensions will be :math:`s \sim \mathcal{U}(a, b)`.
center: If ``'image'``, rotations and scaling will be performed around
the image center. If ``'origin'``, rotations and scaling will be
performed around the origin in world coordinates.
default_pad_value: As the image is rotated, some values near the
borders will be undefined.
If ``'minimum'``, the fill value will be the image minimum.
If ``'mean'``, the fill value is the mean of the border values.
If ``'otsu'``, the fill value is the mean of the values at the
border that lie under an
`Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
If it is a number, that value will be used.
image_interpolation: See :ref:`Interpolation`.
label_interpolation: See :ref:`Interpolation`.
check_shape: If ``True``, an error will be raised if the images are in
different physical spaces. If ``False``, :attr:`center` should
probably not be ``'image'`` but ``'origin'``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
Example:
>>> import torchio as tio
>>> image = tio.datasets.Colin27().t1
>>> transform = tio.RandomAffine(
... scales=(0.9, 1.2),
... degrees=15,
... )
>>> transformed = transform(image)
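A further sketch combining the argument forms documented above
(illustrative values only):
>>> transform = tio.RandomAffine(
...     scales=(0.9, 1.1),               # s ~ U(0.9, 1.1)
...     degrees=(0, 0, 0, 0, -10, 10),   # rotate only around the third axis
...     translation=5,                   # each translation ~ U(-5, 5) mm
...     isotropic=True,                  # one factor shared by all three axes
... )
>>> transformed = transform(image)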
.. plot::
import torchio as tio
subject = tio.datasets.Slicer('CTChest')
ct = subject.CT_chest
transform = tio.RandomAffine()
ct_transformed = transform(ct)
subject.add_image(ct_transformed, 'Transformed')
subject.plot()
"""
def __init__(
self,
scales: TypeOneToSixFloat = 0.1,
degrees: TypeOneToSixFloat = 10,
translation: TypeOneToSixFloat = 0,
isotropic: bool = False,
center: str = 'image',
default_pad_value: Union[str, float] = 'minimum',
image_interpolation: str = 'linear',
label_interpolation: str = 'nearest',
check_shape: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.isotropic = isotropic
_parse_scales_isotropic(scales, isotropic)
self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
self.degrees = self.parse_params(degrees, 0, 'degrees')
self.translation = self.parse_params(translation, 0, 'translation')
if center not in ('image', 'origin'):
message = f'Center argument must be "image" or "origin", not "{center}"'
raise ValueError(message)
self.center = center
self.default_pad_value = _parse_default_value(default_pad_value)
self.image_interpolation = self.parse_interpolation(
image_interpolation,
)
self.label_interpolation = self.parse_interpolation(
label_interpolation,
)
self.check_shape = check_shape
def get_params(
self,
scales: TypeSextetFloat,
degrees: TypeSextetFloat,
translation: TypeSextetFloat,
isotropic: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scaling_params = torch.as_tensor(
self.sample_uniform_sextet(scales),
dtype=torch.float64,
)
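# With isotropic scaling, the first sampled factor is reused for all three axes.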
if isotropic:
scaling_params.fill_(scaling_params[0])
rotation_params = torch.as_tensor(
self.sample_uniform_sextet(degrees),
dtype=torch.float64,
)
translation_params = torch.as_tensor(
self.sample_uniform_sextet(translation),
dtype=torch.float64,
)
return scaling_params, rotation_params, translation_params
def apply_transform(self, subject: Subject) -> Subject:
scaling_params, rotation_params, translation_params = self.get_params(
self.scales,
self.degrees,
self.translation,
self.isotropic,
)
arguments = {
'scales': scaling_params,
'degrees': rotation_params,
'translation': translation_params,
'center': self.center,
'default_pad_value': self.default_pad_value,
'image_interpolation': self.image_interpolation,
'label_interpolation': self.label_interpolation,
'check_shape': self.check_shape,
}
transform = Affine(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
class Affine(SpatialTransform):
r"""Apply affine transformation.
Args:
scales: Tuple :math:`(s_1, s_2, s_3)` defining the
scaling values along each dimension.
degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
rotation around each axis.
translation: Tuple :math:`(t_1, t_2, t_3)` defining the
translation in mm along each axis.
center: If ``'image'``, rotations and scaling will be performed around
the image center. If ``'origin'``, rotations and scaling will be
performed around the origin in world coordinates.
default_pad_value: As the image is rotated, some values near the
borders will be undefined.
If ``'minimum'``, the fill value will be the image minimum.
If ``'mean'``, the fill value is the mean of the border values.
If ``'otsu'``, the fill value is the mean of the values at the
border that lie under an
`Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
If it is a number, that value will be used.
image_interpolation: See :ref:`Interpolation`.
label_interpolation: See :ref:`Interpolation`.
check_shape: If ``True``, an error will be raised if the images are in
different physical spaces. If ``False``, :attr:`center` should
probably not be ``'image'`` but ``'origin'``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
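Example (a minimal sketch with illustrative values; unlike
:class:`RandomAffine`, the given parameters are applied deterministically):
>>> import torchio as tio
>>> image = tio.datasets.Colin27().t1
>>> transform = tio.Affine(
...     scales=(1.1, 1.1, 1.1),
...     degrees=(0, 0, 10),
...     translation=(5, 0, 0),
... )
>>> transformed = transform(image)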
"""
def __init__(
self,
scales: TypeTripletFloat,
degrees: TypeTripletFloat,
translation: TypeTripletFloat,
center: str = 'image',
default_pad_value: Union[str, float] = 'minimum',
image_interpolation: str = 'linear',
label_interpolation: str = 'nearest',
check_shape: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.scales = self.parse_params(
scales,
None,
'scales',
make_ranges=False,
min_constraint=0,
)
self.degrees = self.parse_params(
degrees,
None,
'degrees',
make_ranges=False,
)
self.translation = self.parse_params(
translation,
None,
'translation',
make_ranges=False,
)
if center not in ('image', 'origin'):
message = f'Center argument must be "image" or "origin", not "{center}"'
raise ValueError(message)
self.center = center
self.use_image_center = center == 'image'
self.default_pad_value = _parse_default_value(default_pad_value)
self.image_interpolation = self.parse_interpolation(
image_interpolation,
)
self.label_interpolation = self.parse_interpolation(
label_interpolation,
)
self.invert_transform = False
self.check_shape = check_shape
self.args_names = [
'scales',
'degrees',
'translation',
'center',
'default_pad_value',
'image_interpolation',
'label_interpolation',
'check_shape',
]
@staticmethod
def _get_scaling_transform(
scaling_params: Sequence[float],
center_lps: Optional[TypeTripletFloat] = None,
) -> sitk.ScaleTransform:
# 1.5 means the objects look 1.5 times larger
transform = sitk.ScaleTransform(3)
scaling_params_array = np.array(scaling_params).astype(float)
transform.SetScale(scaling_params_array)
if center_lps is not None:
transform.SetCenter(center_lps)
return transform
@staticmethod
def _get_rotation_transform(
degrees: Sequence[float],
translation: Sequence[float],
center_lps: Optional[TypeTripletFloat] = None,
) -> sitk.Euler3DTransform:
def ras_to_lps(triplet: Sequence[float]):
return np.array((-1, -1, 1), dtype=float) * np.asarray(triplet)
transform = sitk.Euler3DTransform()
radians = np.radians(degrees).tolist()
# SimpleITK uses LPS
radians_lps = ras_to_lps(radians)
translation_lps = ras_to_lps(translation)
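# e.g. a RAS rotation/translation of (10, 20, 30) becomes (-10, -20, 30) in LPS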
transform.SetRotation(*radians_lps)
transform.SetTranslation(translation_lps)
if center_lps is not None:
transform.SetCenter(center_lps)
return transform
def get_affine_transform(self, image):
scaling = np.asarray(self.scales).copy()
rotation = np.asarray(self.degrees).copy()
translation = np.asarray(self.translation).copy()
if image.is_2d():
scaling[2] = 1
rotation[:-1] = 0
if self.use_image_center:
center_lps = image.get_center(lps=True)
else:
center_lps = None
scaling_transform = self._get_scaling_transform(
scaling,
center_lps=center_lps,
)
rotation_transform = self._get_rotation_transform(
rotation,
translation,
center_lps=center_lps,
)
sitk_major_version = get_major_sitk_version()
if sitk_major_version == 1:
transform = sitk.Transform(3, sitk.sitkComposite)
transform.AddTransform(scaling_transform)
transform.AddTransform(rotation_transform)
elif sitk_major_version == 2:
transforms = [scaling_transform, rotation_transform]
transform = sitk.CompositeTransform(transforms)
# ResampleImageFilter expects the transform from the output space to
# the input space. Intuitively, the passed arguments should take us
# from the input space to the output space, so we need to invert the
# transform.
# More info at https://github.com/fepegar/torchio/discussions/693
transform = transform.GetInverse()
if self.invert_transform:
transform = transform.GetInverse()
return transform
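# A quick way to sanity-check the direction of the returned transform
# (illustrative sketch, not executed here):
#
#     point = (10.0, 20.0, 30.0)
#     forward = sitk.CompositeTransform([scaling_transform, rotation_transform])
#     backward = forward.GetInverse()
#     round_trip = forward.TransformPoint(backward.TransformPoint(point))
#     # round_trip should equal ``point`` up to floating-point error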
def apply_transform(self, subject: Subject) -> Subject:
if self.check_shape:
subject.check_consistent_spatial_shape()
default_value: float
for image in self.get_images(subject):
transform = self.get_affine_transform(image)
transformed_tensors = []
for tensor in image.data:
sitk_image = nib_to_sitk(
tensor[np.newaxis],
image.affine,
force_3d=True,
)
if image[TYPE] != INTENSITY:
interpolation = self.label_interpolation
default_value = 0
else:
interpolation = self.image_interpolation
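# Choose the fill value for voxels that fall outside the original image
# (see ``default_pad_value`` in the class docstring).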
if self.default_pad_value == 'minimum':
default_value = tensor.min().item()
elif self.default_pad_value == 'mean':
default_value = get_borders_mean(
sitk_image,
filter_otsu=False,
)
elif self.default_pad_value == 'otsu':
default_value = get_borders_mean(
sitk_image,
filter_otsu=True,
)
else:
assert isinstance(self.default_pad_value, Number)
default_value = float(self.default_pad_value)
transformed_tensor = self.apply_affine_transform(
sitk_image,
transform,
interpolation,
default_value,
)
transformed_tensors.append(transformed_tensor)
image.set_data(torch.stack(transformed_tensors))
return subject
def apply_affine_transform(
self,
sitk_image: sitk.Image,
transform: sitk.Transform,
interpolation: str,
default_value: float,
) -> torch.Tensor:
floating = reference = sitk_image
resampler = sitk.ResampleImageFilter()
resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
resampler.SetReferenceImage(reference)
resampler.SetDefaultPixelValue(float(default_value))
resampler.SetOutputPixelType(sitk.sitkFloat32)
resampler.SetTransform(transform)
resampled = resampler.Execute(floating)
np_array = sitk.GetArrayFromImage(resampled)
np_array = np_array.transpose() # ITK to NumPy
tensor = torch.as_tensor(np_array)
return tensor
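# ``get_borders_mean`` computes the mean intensity of the voxels on the six
# faces of the image. With ``filter_otsu=True``, only border voxels below an
# Otsu threshold (i.e. likely background) contribute to the mean; if no such
# voxels are found (or they are all zero), the plain border mean is used.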
def get_borders_mean(image, filter_otsu=True):
array = sitk.GetArrayViewFromImage(image)
borders_tuple = (
array[0, :, :],
array[-1, :, :],
array[:, 0, :],
array[:, -1, :],
array[:, :, 0],
array[:, :, -1],
)
borders_flat = np.hstack([border.ravel() for border in borders_tuple])
if not filter_otsu:
return borders_flat.mean()
borders_reshaped = borders_flat.reshape(1, 1, -1)
borders_image = sitk.GetImageFromArray(borders_reshaped)
otsu = sitk.OtsuThresholdImageFilter()
otsu.Execute(borders_image)
threshold = otsu.GetThreshold()
values = borders_flat[borders_flat < threshold]
if values.any():
default_value = values.mean()
else:
default_value = borders_flat.mean()
return default_value
def _parse_scales_isotropic(scales, isotropic):
scales = to_tuple(scales)
if isotropic and len(scales) in (3, 6):
message = (
'If "isotropic" is True, the value for "scales" must have'
f' length 1 or 2, but "{scales}" was passed.'
' If you want to set isotropic scaling, use a single value or two values as a range'
' for the scaling factor. Refer to the documentation for more information.'
)
raise ValueError(message)
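# For example, scales=(0.9, 1.1, 1.0) with isotropic=True raises, because a
# per-axis specification contradicts a single shared scaling factor.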
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
return value
message = (
'Value for default_pad_value must be "minimum", "otsu", "mean" or a number'
)
raise ValueError(message)