
Commit 16d62e3

qqaatw and NicolasHug authored
Add MPS kernels for nms and roi ops (pytorch#7643)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent f524cd3 commit 16d62e3

15 files changed, +2146 −32 lines

.github/scripts/run-clang-format.py (+1 −1)

@@ -48,7 +48,7 @@
 DEVNULL = open(os.devnull, "wb")


-DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm"


 class ExitStatus:

CMakeLists.txt (+9 −0)

@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
 file(STRINGS version.txt TORCHVISION_VERSION)

 option(WITH_CUDA "Enable CUDA support" OFF)
+option(WITH_MPS "Enable MPS support" OFF)
 option(WITH_PNG "Enable features requiring LibPNG." ON)
 option(WITH_JPEG "Enable features requiring LibJPEG." ON)
 option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 endif()

+if(WITH_MPS)
+  enable_language(OBJC OBJCXX)
+  add_definitions(-DWITH_MPS)
+endif()
+
 find_package(Torch REQUIRED)

 if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
 if(WITH_CUDA)
     list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
 endif()
+if(WITH_MPS)
+    list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
+endif()

 FOREACH(DIR ${ALLOW_LISTED})
     file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)

setup.py (+5 −0)

@@ -137,10 +137,13 @@ def get_extensions():
         + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
         + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
     )
+    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

     print("Compiling extensions with following flags:")
     force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
     print(f" FORCE_CUDA: {force_cuda}")
+    force_mps = os.getenv("FORCE_MPS", "0") == "1"
+    print(f" FORCE_MPS: {force_mps}")
     debug_mode = os.getenv("DEBUG", "0") == "1"
     print(f" DEBUG: {debug_mode}")
     use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,8 @@ def get_extensions():
         define_macros += [("WITH_HIP", None)]
         nvcc_flags = []
         extra_compile_args["nvcc"] = nvcc_flags
+    elif torch.backends.mps.is_available() or force_mps:
+        sources += source_mps

     if sys.platform == "win32":
         define_macros += [("torchvision_EXPORTS", None)]
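Once the extension is built on an Apple-silicon machine (or with the new FORCE_MPS=1 escape hatch), a quick smoke test is to run one of the ported roi ops on an MPS tensor. This is a hedged sketch, not code from the commit; the tensor shapes and box values are made up for illustration:

import torch
import torchvision.ops as ops

if torch.backends.mps.is_available():
    x = torch.rand(1, 8, 16, 16, device="mps")
    # One box in (batch_index, x1, y1, x2, y2) format.
    rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]], device="mps")
    out = ops.roi_align(x, rois, output_size=(4, 4))
    print(out.shape)  # torch.Size([1, 8, 4, 4]) if the MPS kernel was compiled in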

test/common_utils.py (+11 −0)

@@ -34,6 +34,7 @@
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
+MPS_NOT_AVAILABLE_MSG = "MPS device not available"
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."


@@ -130,12 +131,22 @@ def cpu_and_cuda():
     return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


+def cpu_and_cuda_and_mps():
+    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)
+
+
 def needs_cuda(test_func):
     import pytest  # noqa

     return pytest.mark.needs_cuda(test_func)


+def needs_mps(test_func):
+    import pytest  # noqa
+
+    return pytest.mark.needs_mps(test_func)
+
+
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
     tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
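For context, a minimal sketch of how these new helpers are meant to be consumed from a test module; the test names and tensor contents below are invented for illustration (the real usage is in test/test_ops.py further down):

import pytest
import torch

from common_utils import cpu_and_cuda_and_mps, needs_mps


@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_runs_on_every_device(device):
    # The "mps" instance carries the needs_mps marker and is skipped on machines without MPS.
    x = torch.ones(3, device=device)
    assert x.sum().item() == 3


@needs_mps
def test_mps_only():
    assert torch.backends.mps.is_available()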

test/conftest.py (+16 −1)

@@ -8,12 +8,20 @@

 torchvision.disable_beta_transforms_warning()

-from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
+from common_utils import (
+    CUDA_NOT_AVAILABLE_MSG,
+    IN_FBCODE,
+    IN_OSS_CI,
+    IN_RE_WORKER,
+    MPS_NOT_AVAILABLE_MSG,
+    OSS_CI_GPU_NO_CUDA_MSG,
+)


 def pytest_configure(config):
     # register an additional marker (see pytest_collection_modifyitems)
     config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
+    config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
     config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")

@@ -37,12 +45,16 @@ def pytest_collection_modifyitems(items):
         # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
         # and the ones with device == 'cpu' won't have the mark.
         needs_cuda = item.get_closest_marker("needs_cuda") is not None
+        needs_mps = item.get_closest_marker("needs_mps") is not None

         if needs_cuda and not torch.cuda.is_available():
             # In general, we skip cuda tests on machines without a GPU
             # There are special cases though, see below
             item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

+        if needs_mps and not torch.backends.mps.is_available():
+            item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
+
         if IN_FBCODE:
             # fbcode doesn't like skipping tests, so instead we just don't collect the test
             # so that they don't even "exist", hence the continue statements.
@@ -54,6 +66,9 @@ def pytest_collection_modifyitems(items):
                 # TODO: something more robust would be to do that only in a sandcastle instance,
                 # so that we can still see the test being skipped when testing locally from a devvm
                 continue
+            if needs_mps and not torch.backends.mps.is_available():
+                # Same as above, but for MPS
+                continue
         elif IN_OSS_CI:
             # Here we're not in fbcode, so we can safely collect and skip tests.
             if not needs_cuda and torch.cuda.is_available():

test/test_ops.py (+84 −25)

@@ -10,7 +10,7 @@
 import torch
 import torch.fx
 import torch.nn.functional as F
-from common_utils import assert_equal, cpu_and_cuda, needs_cuda
+from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
 from torch.autograd import gradcheck
@@ -96,12 +96,33 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:

 class RoIOpTester(ABC):
     dtype = torch.float64
+    mps_dtype = torch.float32
+    mps_backward_atol = 2e-2

-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
-        x_dtype = self.dtype if x_dtype is None else x_dtype
-        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
+    @pytest.mark.parametrize(
+        "x_dtype",
+        (
+            torch.float16,
+            torch.float32,
+            torch.float64,
+        ),
+        ids=str,
+    )
+    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
+        if device == "mps" and x_dtype is torch.float64:
+            pytest.skip("MPS does not support float64")
+
+        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
+
+        tol = 1e-5
+        if x_dtype is torch.half:
+            if device == "mps":
+                tol = 5e-3
+            else:
+                tol = 4e-3
+
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size**2)
@@ -120,10 +141,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
-            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
+            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
         )

-        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -155,16 +175,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
         torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic=False):
+        atol = self.mps_backward_atol if device == "mps" else 1e-05
+        dtype = self.mps_dtype if device == "mps" else self.dtype
+
         torch.random.manual_seed(seed)
         pool_size = 2
-        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor(
-            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device  # format is (xyxy)
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
         )

         def func(z):
@@ -173,9 +196,25 @@ def func(z):
         script_func = self.get_script_fn(rois, pool_size)

         with DeterministicGuard(deterministic):
-            gradcheck(func, (x,))
+            gradcheck(func, (x,), atol=atol)
+
+        gradcheck(script_func, (x,), atol=atol)

-        gradcheck(script_func, (x,))
+    @needs_mps
+    def test_mps_error_inputs(self):
+        pool_size = 2
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
+        rois = torch.tensor(
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
+        )
+
+        def func(z):
+            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
+
+        with pytest.raises(
+            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
+        ):
+            gradcheck(func, (x,))

     @needs_cuda
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@@ -271,6 +310,8 @@ def test_jit_boxes_list(self):


 class TestPSRoIPool(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)

@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):


 class TestRoIAlign(RoIOpTester):
+    mps_backward_atol = 6e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
         return ops.RoIAlign(
             (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
@@ -418,10 +461,11 @@ def test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)

     @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
+    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
@@ -450,7 +494,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
         )

     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic):
@@ -537,6 +581,8 @@ def test_jit_boxes_list(self):


 class TestPSRoIAlign(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

@@ -705,40 +751,53 @@ def test_qnms(self, iou, scale, zero_point):

         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

-    @needs_cuda
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cuda(self, iou, dtype=torch.float64):
+    def test_nms_gpu(self, iou, device, dtype=torch.float64):
+        dtype = torch.float32 if device == "mps" else dtype
         tol = 1e-3 if dtype is torch.half else 1e-5
         err_msg = "NMS incompatible between CPU and CUDA for IoU={}"

         boxes, scores = self._create_tensors_with_iou(1000, iou)
         r_cpu = ops.nms(boxes, scores, iou)
-        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
+        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

-        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
+        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
         if not is_eq:
             # if the indices are not the same, ensure that it's because the scores
             # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
+            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)

     @needs_cuda
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

-    @needs_cuda
-    def test_nms_cuda_float16(self):
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
+    def test_nms_float16(self, device):
         boxes = torch.tensor(
             [
                 [285.3538, 185.5758, 1193.5110, 851.4551],
                 [285.1472, 188.7374, 1192.4984, 851.0669],
                 [279.2440, 197.9812, 1189.4746, 849.2019],
             ]
-        ).cuda()
-        scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
+        ).to(device)
+        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

         iou_thres = 0.2
         keep32 = ops.nms(boxes, scores, iou_thres)
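The new MPS path can also be exercised outside pytest with a quick CPU-vs-MPS consistency check of the same shape as test_nms_gpu above. This is a standalone sketch, not code from the commit; the box coordinates are made up:

import torch
import torchvision.ops as ops

if torch.backends.mps.is_available():
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [20.0, 20.0, 30.0, 30.0]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep_cpu = ops.nms(boxes, scores, 0.5)
    keep_mps = ops.nms(boxes.to("mps"), scores.to("mps"), 0.5)
    # Kept indices should match the CPU reference (up to ties in the scores).
    assert torch.equal(keep_cpu, keep_mps.cpu())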

torchvision/csrc/ops/cpu/nms_kernel.cpp (+2 −2)

@@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
-  TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
-  TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
+  TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
+  TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
   TORCH_CHECK(
       dets.scalar_type() == scores.scalar_type(),
       "dets should have the same type as scores");
(new file, +6 −0)

@@ -0,0 +1,6 @@
+constexpr int threadsPerBlock = 512;
+
+template <typename T>
+constexpr inline T ceil_div(T n, T m) {
+  return (n + m - 1) / m;
+}
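For intuition, here is a Python rendering of the helper above together with the kind of grid-size computation such a helper typically serves; this is an illustration only, and the actual kernel-side use lives in the MPS sources not shown in this excerpt:

def ceil_div(n: int, m: int) -> int:
    # Python equivalent of the C++ ceil_div template above.
    return (n + m - 1) // m

threads_per_block = 512   # mirrors threadsPerBlock above
num_boxes = 1000          # example workload size (made up)
num_blocks = ceil_div(num_boxes, threads_per_block)
assert num_blocks == 2    # 1000 items need two blocks of 512 threads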
