@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
551
551
552
552
553
553
def _apply_grid_transform (img : torch .Tensor , grid : torch .Tensor , mode : str , fill : _FillTypeJIT ) -> torch .Tensor :
554
+ input_shape = img .shape
555
+ output_height , output_width = grid .shape [1 ], grid .shape [2 ]
556
+ num_channels , input_height , input_width = input_shape [- 3 :]
557
+ output_shape = input_shape [:- 3 ] + (num_channels , output_height , output_width )
558
+
559
+ if img .numel () == 0 :
560
+ return img .reshape (output_shape )
561
+
562
+ img = img .reshape (- 1 , num_channels , input_height , input_width )
563
+ squashed_batch_size = img .shape [0 ]
554
564
555
565
# We are using context knowledge that grid should have float dtype
556
566
fp = img .dtype == grid .dtype
557
567
float_img = img if fp else img .to (grid .dtype )
558
568
559
- shape = float_img .shape
560
- if shape [0 ] > 1 :
569
+ if squashed_batch_size > 1 :
561
570
# Apply same grid to a batch of images
562
- grid = grid .expand (shape [ 0 ] , - 1 , - 1 , - 1 )
571
+ grid = grid .expand (squashed_batch_size , - 1 , - 1 , - 1 )
563
572
564
573
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
565
574
if fill is not None :
566
- mask = torch .ones ((shape [0 ], 1 , shape [2 ], shape [3 ]), dtype = float_img .dtype , device = float_img .device )
575
+ mask = torch .ones (
576
+ (squashed_batch_size , 1 , input_height , input_width ), dtype = float_img .dtype , device = float_img .device
577
+ )
567
578
float_img = torch .cat ((float_img , mask ), dim = 1 )
568
579
569
580
float_img = grid_sample (float_img , grid , mode = mode , padding_mode = "zeros" , align_corners = False )
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
584
595
585
596
img = float_img .round_ ().to (img .dtype ) if not fp else float_img
586
597
587
- return img
598
+ return img . reshape ( output_shape )
588
599
589
600
590
601
def _assert_grid_transform_inputs (
@@ -661,24 +672,10 @@ def affine_image(
661
672
) -> torch .Tensor :
662
673
interpolation = _check_interpolation (interpolation )
663
674
664
- if image .numel () == 0 :
665
- return image
666
-
667
- shape = image .shape
668
- ndim = image .ndim
669
-
670
- if ndim > 4 :
671
- image = image .reshape ((- 1 ,) + shape [- 3 :])
672
- needs_unsquash = True
673
- elif ndim == 3 :
674
- image = image .unsqueeze (0 )
675
- needs_unsquash = True
676
- else :
677
- needs_unsquash = False
678
-
679
- height , width = shape [- 2 :]
680
675
angle , translate , shear , center = _affine_parse_args (angle , translate , scale , shear , interpolation , center )
681
676
677
+ height , width = image .shape [- 2 :]
678
+
682
679
center_f = [0.0 , 0.0 ]
683
680
if center is not None :
684
681
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
@@ -692,12 +689,7 @@ def affine_image(
692
689
dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
693
690
theta = torch .tensor (matrix , dtype = dtype , device = image .device ).reshape (1 , 2 , 3 )
694
691
grid = _affine_grid (theta , w = width , h = height , ow = width , oh = height )
695
- output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
696
-
697
- if needs_unsquash :
698
- output = output .reshape (shape )
699
-
700
- return output
692
+ return _apply_grid_transform (image , grid , interpolation .value , fill = fill )
701
693
702
694
703
695
@_register_kernel_internal (affine , PIL .Image .Image )
@@ -969,35 +961,26 @@ def rotate_image(
969
961
) -> torch .Tensor :
970
962
interpolation = _check_interpolation (interpolation )
971
963
972
- shape = image .shape
973
- num_channels , height , width = shape [- 3 :]
964
+ input_height , input_width = image .shape [- 2 :]
974
965
975
966
center_f = [0.0 , 0.0 ]
976
967
if center is not None :
977
968
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
978
- center_f = [(c - s * 0.5 ) for c , s in zip (center , [width , height ])]
969
+ center_f = [(c - s * 0.5 ) for c , s in zip (center , [input_width , input_height ])]
979
970
980
971
# due to current incoherence of rotation angle direction between affine and rotate implementations
981
972
# we need to set -angle.
982
973
matrix = _get_inverse_affine_matrix (center_f , - angle , [0.0 , 0.0 ], 1.0 , [0.0 , 0.0 ])
983
974
984
- if image .numel () > 0 :
985
- image = image .reshape (- 1 , num_channels , height , width )
986
-
987
- _assert_grid_transform_inputs (image , matrix , interpolation .value , fill , ["nearest" , "bilinear" ])
988
-
989
- ow , oh = _compute_affine_output_size (matrix , width , height ) if expand else (width , height )
990
- dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
991
- theta = torch .tensor (matrix , dtype = dtype , device = image .device ).reshape (1 , 2 , 3 )
992
- grid = _affine_grid (theta , w = width , h = height , ow = ow , oh = oh )
993
- output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
994
-
995
- new_height , new_width = output .shape [- 2 :]
996
- else :
997
- output = image
998
- new_width , new_height = _compute_affine_output_size (matrix , width , height ) if expand else (width , height )
975
+ _assert_grid_transform_inputs (image , matrix , interpolation .value , fill , ["nearest" , "bilinear" ])
999
976
1000
- return output .reshape (shape [:- 3 ] + (num_channels , new_height , new_width ))
977
+ output_width , output_height = (
978
+ _compute_affine_output_size (matrix , input_width , input_height ) if expand else (input_width , input_height )
979
+ )
980
+ dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
981
+ theta = torch .tensor (matrix , dtype = dtype , device = image .device ).reshape (1 , 2 , 3 )
982
+ grid = _affine_grid (theta , w = input_width , h = input_height , ow = output_width , oh = output_height )
983
+ return _apply_grid_transform (image , grid , interpolation .value , fill = fill )
1001
984
1002
985
1003
986
@_register_kernel_internal (rotate , PIL .Image .Image )
@@ -1509,21 +1492,6 @@ def perspective_image(
1509
1492
perspective_coeffs = _perspective_coefficients (startpoints , endpoints , coefficients )
1510
1493
interpolation = _check_interpolation (interpolation )
1511
1494
1512
- if image .numel () == 0 :
1513
- return image
1514
-
1515
- shape = image .shape
1516
- ndim = image .ndim
1517
-
1518
- if ndim > 4 :
1519
- image = image .reshape ((- 1 ,) + shape [- 3 :])
1520
- needs_unsquash = True
1521
- elif ndim == 3 :
1522
- image = image .unsqueeze (0 )
1523
- needs_unsquash = True
1524
- else :
1525
- needs_unsquash = False
1526
-
1527
1495
_assert_grid_transform_inputs (
1528
1496
image ,
1529
1497
matrix = None ,
@@ -1533,15 +1501,10 @@ def perspective_image(
1533
1501
coeffs = perspective_coeffs ,
1534
1502
)
1535
1503
1536
- oh , ow = shape [- 2 :]
1504
+ oh , ow = image . shape [- 2 :]
1537
1505
dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
1538
1506
grid = _perspective_grid (perspective_coeffs , ow = ow , oh = oh , dtype = dtype , device = image .device )
1539
- output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
1540
-
1541
- if needs_unsquash :
1542
- output = output .reshape (shape )
1543
-
1544
- return output
1507
+ return _apply_grid_transform (image , grid , interpolation .value , fill = fill )
1545
1508
1546
1509
1547
1510
@_register_kernel_internal (perspective , PIL .Image .Image )
@@ -1759,12 +1722,7 @@ def elastic_image(
1759
1722
1760
1723
interpolation = _check_interpolation (interpolation )
1761
1724
1762
- if image .numel () == 0 :
1763
- return image
1764
-
1765
- shape = image .shape
1766
- ndim = image .ndim
1767
-
1725
+ height , width = image .shape [- 2 :]
1768
1726
device = image .device
1769
1727
dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
1770
1728
@@ -1775,32 +1733,18 @@ def elastic_image(
1775
1733
dtype = torch .float32
1776
1734
1777
1735
# We are aware that if input image dtype is uint8 and displacement is float64 then
1778
- # displacement will be casted to float32 and all computations will be done with float32
1736
+ # displacement will be cast to float32 and all computations will be done with float32
1779
1737
# We can fix this later if needed
1780
1738
1781
- expected_shape = (1 ,) + shape [ - 2 :] + ( 2 , )
1739
+ expected_shape = (1 , height , width , 2 )
1782
1740
if expected_shape != displacement .shape :
1783
1741
raise ValueError (f"Argument displacement shape should be { expected_shape } , but given { displacement .shape } " )
1784
1742
1785
- if ndim > 4 :
1786
- image = image .reshape ((- 1 ,) + shape [- 3 :])
1787
- needs_unsquash = True
1788
- elif ndim == 3 :
1789
- image = image .unsqueeze (0 )
1790
- needs_unsquash = True
1791
- else :
1792
- needs_unsquash = False
1793
-
1794
- if displacement .dtype != dtype or displacement .device != device :
1795
- displacement = displacement .to (dtype = dtype , device = device )
1796
-
1797
- image_height , image_width = shape [- 2 :]
1798
- grid = _create_identity_grid ((image_height , image_width ), device = device , dtype = dtype ).add_ (displacement )
1743
+ grid = _create_identity_grid ((height , width ), device = device , dtype = dtype ).add_ (
1744
+ displacement .to (dtype = dtype , device = device )
1745
+ )
1799
1746
output = _apply_grid_transform (image , grid , interpolation .value , fill = fill )
1800
1747
1801
- if needs_unsquash :
1802
- output = output .reshape (shape )
1803
-
1804
1748
if is_cpu_half :
1805
1749
output = output .to (torch .float16 )
1806
1750
0 commit comments