Skip to content

Commit e9cad2d

Browse files
committed
fix elastic float16 output
1 parent 5bac8ca commit e9cad2d

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

test/test_transforms_v2_refactored.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2491,7 +2491,7 @@ def _make_displacement(self, inpt):
24912491
interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
24922492
fill=EXHAUSTIVE_TYPE_FILLS,
24932493
)
2494-
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
2494+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
24952495
@pytest.mark.parametrize("device", cpu_and_cuda())
24962496
def test_kernel_image(self, param, value, dtype, device):
24972497
image = make_image_tensor(dtype=dtype, device=device)
@@ -2502,6 +2502,7 @@ def test_kernel_image(self, param, value, dtype, device):
25022502
displacement=self._make_displacement(image),
25032503
**{param: value},
25042504
check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
2505+
check_cuda_vs_cpu=dtype is not torch.float16,
25052506
)
25062507

25072508
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))

torchvision/transforms/v2/functional/_geometry.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1743,7 +1743,12 @@ def elastic_image(
17431743
grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
17441744
displacement.to(dtype=dtype, device=device)
17451745
)
1746-
return _apply_grid_transform(image, grid, interpolation.value, fill=fill)
1746+
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
1747+
1748+
if is_cpu_half:
1749+
output = output.to(torch.float16)
1750+
1751+
return output
17471752

17481753

17491754
@_register_kernel_internal(elastic, PIL.Image.Image)

0 commit comments

Comments
 (0)