Commit ee28bb3

cleanup affine grid image kernels (pytorch#8004)
1 parent f96deba commit ee28bb3

2 files changed: +39 −94 lines changed

test/test_transforms_v2_refactored.py (+2 −1)
@@ -2491,7 +2491,7 @@ def _make_displacement(self, inpt):
         interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
         fill=EXHAUSTIVE_TYPE_FILLS,
     )
-    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image(self, param, value, dtype, device):
         image = make_image_tensor(dtype=dtype, device=device)
@@ -2502,6 +2502,7 @@ def test_kernel_image(self, param, value, dtype, device):
             displacement=self._make_displacement(image),
             **{param: value},
             check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
+            check_cuda_vs_cpu=dtype is not torch.float16,
         )

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
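The new float16 case runs the kernel itself under half precision but disables the CUDA-vs-CPU consistency check, presumably because half-precision grid_sample results can drift between backends. A minimal sketch of the kind of call this parametrization covers (shapes and magnitudes are illustrative, not taken from the test suite):

import torch
from torchvision.transforms.v2 import functional as F

# Elastic kernel on a float16 image; on CPU the kernel upcasts to float32
# internally and casts the result back to float16 at the end.
image = torch.rand(3, 32, 32, dtype=torch.float16)
displacement = torch.randn(1, 32, 32, 2) * 0.1  # (1, H, W, 2), as the kernel requires

out = F.elastic_image(image, displacement=displacement)
print(out.shape, out.dtype)  # torch.Size([3, 32, 32]) torch.float16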

torchvision/transforms/v2/functional/_geometry.py (+37 −93)
@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:


 def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
+    input_shape = img.shape
+    output_height, output_width = grid.shape[1], grid.shape[2]
+    num_channels, input_height, input_width = input_shape[-3:]
+    output_shape = input_shape[:-3] + (num_channels, output_height, output_width)
+
+    if img.numel() == 0:
+        return img.reshape(output_shape)
+
+    img = img.reshape(-1, num_channels, input_height, input_width)
+    squashed_batch_size = img.shape[0]

     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
     float_img = img if fp else img.to(grid.dtype)

-    shape = float_img.shape
-    if shape[0] > 1:
+    if squashed_batch_size > 1:
         # Apply same grid to a batch of images
-        grid = grid.expand(shape[0], -1, -1, -1)
+        grid = grid.expand(squashed_batch_size, -1, -1, -1)

     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
     if fill is not None:
-        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        mask = torch.ones(
+            (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device
+        )
         float_img = torch.cat((float_img, mask), dim=1)

     float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:

     img = float_img.round_().to(img.dtype) if not fp else float_img

-    return img
+    return img.reshape(output_shape)


 def _assert_grid_transform_inputs(
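Taken together, these two hunks move the batch bookkeeping that every kernel used to duplicate into _apply_grid_transform itself: collapse any leading dims into one batch dim, run the 4D-only grid_sample, then restore the leading dims around the (possibly new) spatial size. A standalone sketch of that pattern (illustrative code, not torchvision internals):

import torch

def apply_4d_op(img: torch.Tensor, op) -> torch.Tensor:
    # Collapse all leading dims (batch, time, ...) into a single batch dim,
    # apply an op that only accepts (B, C, H, W), then restore the leading
    # dims around whatever spatial size the op produced.
    input_shape = img.shape
    num_channels, height, width = input_shape[-3:]
    out = op(img.reshape(-1, num_channels, height, width))
    return out.reshape(input_shape[:-3] + out.shape[-3:])

video_batch = torch.rand(2, 5, 3, 8, 8)  # e.g. (N, T, C, H, W)
flipped = apply_4d_op(video_batch, lambda x: torch.flip(x, dims=[-1]))
print(flipped.shape)  # torch.Size([2, 5, 3, 8, 8])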
@@ -661,24 +672,10 @@ def affine_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    height, width = shape[-2:]
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

+    height, width = image.shape[-2:]
+
     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
@@ -692,12 +689,7 @@ def affine_image(
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(affine, PIL.Image.Image)
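With the squash/unsquash logic gone, affine_image can hand any input straight to _apply_grid_transform. An illustrative check, not from the PR, that batched and zero-element inputs keep their shape (assuming a torchvision build where the v2 kernels are importable as below):

import torch
from torchvision.transforms.v2 import functional as F

video = torch.rand(2, 4, 3, 16, 16)  # arbitrary leading dims before (C, H, W)
out = F.affine_image(video, angle=30.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0])
print(out.shape)  # torch.Size([2, 4, 3, 16, 16])

empty = torch.rand(0, 3, 16, 16)  # handled by the early return in _apply_grid_transform
print(F.affine_image(empty, angle=30.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0]).shape)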
@@ -969,35 +961,26 @@ def rotate_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    shape = image.shape
-    num_channels, height, width = shape[-3:]
+    input_height, input_width = image.shape[-2:]

     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]
+        center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

-    if image.numel() > 0:
-        image = image.reshape(-1, num_channels, height, width)
-
-        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
-
-        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
-        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-        new_height, new_width = output.shape[-2:]
-    else:
-        output = image
-        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
+    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

-    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
+    output_width, output_height = (
+        _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height)
+    )
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+    grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height)
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(rotate, PIL.Image.Image)
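The expand path is now computed once up front and flows through the same return; _apply_grid_transform reshapes the result to the enlarged canvas regardless of leading dims. A quick illustration (the expanded size is approximate, not taken from the PR):

import torch
from torchvision.transforms.v2 import functional as F

image = torch.rand(3, 20, 40)
print(F.rotate_image(image, angle=45.0).shape)  # torch.Size([3, 20, 40])
# With expand=True the canvas grows to fit the rotated content,
# roughly w*cos + h*sin per side at 45 degrees (about 43x43 here).
print(F.rotate_image(image, angle=45.0, expand=True).shape)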
@@ -1509,21 +1492,6 @@ def perspective_image(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
     _assert_grid_transform_inputs(
         image,
         matrix=None,
@@ -1533,15 +1501,10 @@ def perspective_image(
         coeffs=perspective_coeffs,
     )

-    oh, ow = shape[-2:]
+    oh, ow = image.shape[-2:]
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(perspective, PIL.Image.Image)
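perspective_image gets the same treatment: no ndim bookkeeping, just grid construction plus _apply_grid_transform. An illustrative batched call with arbitrary coefficients (when coefficients are passed directly, startpoints/endpoints are left as None):

import torch
from torchvision.transforms.v2 import functional as F

batch = torch.rand(4, 3, 32, 32)
out = F.perspective_image(
    batch,
    startpoints=None,
    endpoints=None,
    coefficients=[1.0, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
)
print(out.shape)  # torch.Size([4, 3, 32, 32])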
@@ -1759,12 +1722,7 @@ def elastic_image(

     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
+    height, width = image.shape[-2:]
     device = image.device
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32

@@ -1775,32 +1733,18 @@ def elastic_image(
         dtype = torch.float32

     # We are aware that if input image dtype is uint8 and displacement is float64 then
-    # displacement will be casted to float32 and all computations will be done with float32
+    # displacement will be cast to float32 and all computations will be done with float32
     # We can fix this later if needed

-    expected_shape = (1,) + shape[-2:] + (2,)
+    expected_shape = (1, height, width, 2)
     if expected_shape != displacement.shape:
         raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    if displacement.dtype != dtype or displacement.device != device:
-        displacement = displacement.to(dtype=dtype, device=device)
-
-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
+        displacement.to(dtype=dtype, device=device)
+    )
     output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

-    if needs_unsquash:
-        output = output.reshape(shape)
-
     if is_cpu_half:
         output = output.to(torch.float16)

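For reference, the grid elastic_image hands to _apply_grid_transform is just an identity sampling grid plus the user displacement. A standalone approximation of that construction (this is not torchvision's _create_identity_grid, but follows the same align_corners=False convention):

import torch

def identity_plus_displacement(height: int, width: int, displacement: torch.Tensor) -> torch.Tensor:
    # Identity grid in normalized coordinates, align_corners=False style:
    # sample centers sit at +/-(1 - 1/size) rather than exactly +/-1.
    ys = torch.linspace(-1.0 + 1.0 / height, 1.0 - 1.0 / height, height)
    xs = torch.linspace(-1.0 + 1.0 / width, 1.0 - 1.0 / width, width)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    identity = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # (1, H, W, 2)
    return identity + displacement.to(identity.dtype)

image = torch.rand(1, 3, 16, 16)
grid = identity_plus_displacement(16, 16, torch.randn(1, 16, 16, 2) * 0.05)
warped = torch.nn.functional.grid_sample(image, grid, align_corners=False)
print(warped.shape)  # torch.Size([1, 3, 16, 16])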