10
10
import torch
11
11
import torch .fx
12
12
import torch .nn .functional as F
13
- from common_utils import assert_equal , cpu_and_cuda , needs_cuda
13
+ from common_utils import assert_equal , cpu_and_cuda , cpu_and_cuda_and_mps , needs_cuda , needs_mps
14
14
from PIL import Image
15
15
from torch import nn , Tensor
16
16
from torch .autograd import gradcheck
@@ -96,12 +96,33 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
96
96
97
97
class RoIOpTester (ABC ):
98
98
dtype = torch .float64
99
+ mps_dtype = torch .float32
100
+ mps_backward_atol = 2e-2
99
101
100
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
102
+ @pytest .mark .parametrize ("device" , cpu_and_cuda_and_mps ())
101
103
@pytest .mark .parametrize ("contiguous" , (True , False ))
102
- def test_forward (self , device , contiguous , x_dtype = None , rois_dtype = None , deterministic = False , ** kwargs ):
103
- x_dtype = self .dtype if x_dtype is None else x_dtype
104
- rois_dtype = self .dtype if rois_dtype is None else rois_dtype
104
+ @pytest .mark .parametrize (
105
+ "x_dtype" ,
106
+ (
107
+ torch .float16 ,
108
+ torch .float32 ,
109
+ torch .float64 ,
110
+ ),
111
+ ids = str ,
112
+ )
113
+ def test_forward (self , device , contiguous , x_dtype , rois_dtype = None , deterministic = False , ** kwargs ):
114
+ if device == "mps" and x_dtype is torch .float64 :
115
+ pytest .skip ("MPS does not support float64" )
116
+
117
+ rois_dtype = x_dtype if rois_dtype is None else rois_dtype
118
+
119
+ tol = 1e-5
120
+ if x_dtype is torch .half :
121
+ if device == "mps" :
122
+ tol = 5e-3
123
+ else :
124
+ tol = 4e-3
125
+
105
126
pool_size = 5
106
127
# n_channels % (pool_size ** 2) == 0 required for PS operations.
107
128
n_channels = 2 * (pool_size ** 2 )
@@ -120,10 +141,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ
120
141
# the following should be true whether we're running an autocast test or not.
121
142
assert y .dtype == x .dtype
122
143
gt_y = self .expected_fn (
123
- x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , device = device , dtype = self . dtype , ** kwargs
144
+ x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , device = device , dtype = x_dtype , ** kwargs
124
145
)
125
146
126
- tol = 1e-3 if (x_dtype is torch .half or rois_dtype is torch .half ) else 1e-5
127
147
torch .testing .assert_close (gt_y .to (y ), y , rtol = tol , atol = tol )
128
148
129
149
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
@@ -155,16 +175,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
155
175
torch .testing .assert_close (output_gt , output_fx , rtol = tol , atol = tol )
156
176
157
177
@pytest .mark .parametrize ("seed" , range (10 ))
158
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
178
+ @pytest .mark .parametrize ("device" , cpu_and_cuda_and_mps ())
159
179
@pytest .mark .parametrize ("contiguous" , (True , False ))
160
180
def test_backward (self , seed , device , contiguous , deterministic = False ):
181
+ atol = self .mps_backward_atol if device == "mps" else 1e-05
182
+ dtype = self .mps_dtype if device == "mps" else self .dtype
183
+
161
184
torch .random .manual_seed (seed )
162
185
pool_size = 2
163
- x = torch .rand (1 , 2 * (pool_size ** 2 ), 5 , 5 , dtype = self . dtype , device = device , requires_grad = True )
186
+ x = torch .rand (1 , 2 * (pool_size ** 2 ), 5 , 5 , dtype = dtype , device = device , requires_grad = True )
164
187
if not contiguous :
165
188
x = x .permute (0 , 1 , 3 , 2 )
166
189
rois = torch .tensor (
167
- [[0 , 0 , 0 , 4 , 4 ], [0 , 0 , 2 , 3 , 4 ], [0 , 2 , 2 , 4 , 4 ]], dtype = self . dtype , device = device # format is (xyxy)
190
+ [[0 , 0 , 0 , 4 , 4 ], [0 , 0 , 2 , 3 , 4 ], [0 , 2 , 2 , 4 , 4 ]], dtype = dtype , device = device # format is (xyxy)
168
191
)
169
192
170
193
def func (z ):
@@ -173,9 +196,25 @@ def func(z):
173
196
script_func = self .get_script_fn (rois , pool_size )
174
197
175
198
with DeterministicGuard (deterministic ):
176
- gradcheck (func , (x ,))
199
+ gradcheck (func , (x ,), atol = atol )
200
+
201
+ gradcheck (script_func , (x ,), atol = atol )
177
202
178
- gradcheck (script_func , (x ,))
203
+ @needs_mps
204
+ def test_mps_error_inputs (self ):
205
+ pool_size = 2
206
+ x = torch .rand (1 , 2 * (pool_size ** 2 ), 5 , 5 , dtype = torch .float16 , device = "mps" , requires_grad = True )
207
+ rois = torch .tensor (
208
+ [[0 , 0 , 0 , 4 , 4 ], [0 , 0 , 2 , 3 , 4 ], [0 , 2 , 2 , 4 , 4 ]], dtype = torch .float16 , device = "mps" # format is (xyxy)
209
+ )
210
+
211
+ def func (z ):
212
+ return self .fn (z , rois , pool_size , pool_size , spatial_scale = 1 , sampling_ratio = 1 )
213
+
214
+ with pytest .raises (
215
+ RuntimeError , match = "MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
216
+ ):
217
+ gradcheck (func , (x ,))
179
218
180
219
@needs_cuda
181
220
@pytest .mark .parametrize ("x_dtype" , (torch .float , torch .half ))
@@ -271,6 +310,8 @@ def test_jit_boxes_list(self):
271
310
272
311
273
312
class TestPSRoIPool (RoIOpTester ):
313
+ mps_backward_atol = 5e-2
314
+
274
315
def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs ):
275
316
return ops .PSRoIPool ((pool_h , pool_w ), 1 )(x , rois )
276
317
@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):
352
393
353
394
354
395
class TestRoIAlign (RoIOpTester ):
396
+ mps_backward_atol = 6e-2
397
+
355
398
def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , aligned = False , ** kwargs ):
356
399
return ops .RoIAlign (
357
400
(pool_h , pool_w ), spatial_scale = spatial_scale , sampling_ratio = sampling_ratio , aligned = aligned
@@ -418,10 +461,11 @@ def test_boxes_shape(self):
418
461
self ._helper_boxes_shape (ops .roi_align )
419
462
420
463
@pytest .mark .parametrize ("aligned" , (True , False ))
421
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
464
+ @pytest .mark .parametrize ("device" , cpu_and_cuda_and_mps ())
465
+ @pytest .mark .parametrize ("x_dtype" , (torch .float16 , torch .float32 , torch .float64 ), ids = str )
422
466
@pytest .mark .parametrize ("contiguous" , (True , False ))
423
467
@pytest .mark .parametrize ("deterministic" , (True , False ))
424
- def test_forward (self , device , contiguous , deterministic , aligned , x_dtype = None , rois_dtype = None ):
468
+ def test_forward (self , device , contiguous , deterministic , aligned , x_dtype , rois_dtype = None ):
425
469
if deterministic and device == "cpu" :
426
470
pytest .skip ("cpu is always deterministic, don't retest" )
427
471
super ().test_forward (
@@ -450,7 +494,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
450
494
)
451
495
452
496
@pytest .mark .parametrize ("seed" , range (10 ))
453
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
497
+ @pytest .mark .parametrize ("device" , cpu_and_cuda_and_mps ())
454
498
@pytest .mark .parametrize ("contiguous" , (True , False ))
455
499
@pytest .mark .parametrize ("deterministic" , (True , False ))
456
500
def test_backward (self , seed , device , contiguous , deterministic ):
@@ -537,6 +581,8 @@ def test_jit_boxes_list(self):
537
581
538
582
539
583
class TestPSRoIAlign (RoIOpTester ):
584
+ mps_backward_atol = 5e-2
585
+
540
586
def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs ):
541
587
return ops .PSRoIAlign ((pool_h , pool_w ), spatial_scale = spatial_scale , sampling_ratio = sampling_ratio )(x , rois )
542
588
@@ -705,40 +751,53 @@ def test_qnms(self, iou, scale, zero_point):
705
751
706
752
torch .testing .assert_close (qkeep , keep , msg = err_msg .format (iou ))
707
753
708
- @needs_cuda
754
+ @pytest .mark .parametrize (
755
+ "device" ,
756
+ (
757
+ pytest .param ("cuda" , marks = pytest .mark .needs_cuda ),
758
+ pytest .param ("mps" , marks = pytest .mark .needs_mps ),
759
+ ),
760
+ )
709
761
@pytest .mark .parametrize ("iou" , (0.2 , 0.5 , 0.8 ))
710
- def test_nms_cuda (self , iou , dtype = torch .float64 ):
762
+ def test_nms_gpu (self , iou , device , dtype = torch .float64 ):
763
+ dtype = torch .float32 if device == "mps" else dtype
711
764
tol = 1e-3 if dtype is torch .half else 1e-5
712
765
err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
713
766
714
767
boxes , scores = self ._create_tensors_with_iou (1000 , iou )
715
768
r_cpu = ops .nms (boxes , scores , iou )
716
- r_cuda = ops .nms (boxes .cuda ( ), scores .cuda ( ), iou )
769
+ r_gpu = ops .nms (boxes .to ( device ), scores .to ( device ), iou )
717
770
718
- is_eq = torch .allclose (r_cpu , r_cuda .cpu ())
771
+ is_eq = torch .allclose (r_cpu , r_gpu .cpu ())
719
772
if not is_eq :
720
773
# if the indices are not the same, ensure that it's because the scores
721
774
# are duplicate
722
- is_eq = torch .allclose (scores [r_cpu ], scores [r_cuda .cpu ()], rtol = tol , atol = tol )
775
+ is_eq = torch .allclose (scores [r_cpu ], scores [r_gpu .cpu ()], rtol = tol , atol = tol )
723
776
assert is_eq , err_msg .format (iou )
724
777
725
778
@needs_cuda
726
779
@pytest .mark .parametrize ("iou" , (0.2 , 0.5 , 0.8 ))
727
780
@pytest .mark .parametrize ("dtype" , (torch .float , torch .half ))
728
781
def test_autocast (self , iou , dtype ):
729
782
with torch .cuda .amp .autocast ():
730
- self .test_nms_cuda (iou = iou , dtype = dtype )
783
+ self .test_nms_gpu (iou = iou , dtype = dtype , device = "cuda" )
731
784
732
- @needs_cuda
733
- def test_nms_cuda_float16 (self ):
785
+ @pytest .mark .parametrize (
786
+ "device" ,
787
+ (
788
+ pytest .param ("cuda" , marks = pytest .mark .needs_cuda ),
789
+ pytest .param ("mps" , marks = pytest .mark .needs_mps ),
790
+ ),
791
+ )
792
+ def test_nms_float16 (self , device ):
734
793
boxes = torch .tensor (
735
794
[
736
795
[285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
737
796
[285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
738
797
[279.2440 , 197.9812 , 1189.4746 , 849.2019 ],
739
798
]
740
- ).cuda ( )
741
- scores = torch .tensor ([0.6370 , 0.7569 , 0.3966 ]).cuda ( )
799
+ ).to ( device )
800
+ scores = torch .tensor ([0.6370 , 0.7569 , 0.3966 ]).to ( device )
742
801
743
802
iou_thres = 0.2
744
803
keep32 = ops .nms (boxes , scores , iou_thres )
0 commit comments