Commit cb9fdbf

[MPS] Lift MSL version to 3.0+ and use relevant helpers (#8719)
Summary:

1. Remove the custom atomic add function and use the one provided by MSL 3.0+ instead.
2. Use the `MetalShaderLibrary` class.
1 parent 66c5629 commit cb9fdbf
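
To illustrate the first change: once the shader source is compiled as Metal 3.0 or newer, `atomic_fetch_add_explicit` operates directly on a `device atomic_float*`, so the hand-rolled exchange loop removed below is no longer needed. A minimal sketch of that pattern (illustrative only; the kernel name, buffer layout, and accumulator are assumptions, not code from this diff):

#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;

// Illustrative kernel, not part of this commit: every thread folds its input
// element into a single float accumulator through the MSL 3.0 atomic_float type.
kernel void sum_example(device float* acc   [[buffer(0)]],
                        constant float* src [[buffer(1)]],
                        uint tid            [[thread_position_in_grid]]) {
  // Relaxed ordering mirrors the usage in mps_kernels.h: only atomicity of the
  // read-modify-write is required, not ordering against other memory accesses.
  atomic_fetch_add_explicit((device atomic_float*)acc, src[tid], memory_order_relaxed);
}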

File tree: 2 files changed, +14 -74 lines

torchvision/csrc/ops/mps/mps_kernels.h (+14 -73)

@@ -5,7 +5,7 @@ namespace ops {
 
 namespace mps {
 
-static const char* METAL_VISION = R"VISION_METAL(
+static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL(
 
 #include <metal_atomic>
 #include <metal_stdlib>
@@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) {
   return (n + m - 1) / m;
 }
 
-template <typename T>
-inline void atomic_add_float( device T* data_ptr, const T val)
+inline void atomic_add_float(device float* data_ptr, const float val)
 {
-#if __METAL_VERSION__ >= 300
-  // atomic_float is supported in Metal 3 (macOS Ventura) onward.
-  device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
-#else
-  // Custom atomic addition implementation
-  // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
-  // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
-  // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
-
-  // Create an atomic uint pointer for atomic transaction.
-  device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
-  // Create necessary storage.
-  uint fetched_uint, assigning_uint;
-  T fetched_float, assigning_float;
-
-  // Replace the value in atom_var with 0 and return the previous value in atom_var.
-  fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
-  // Read out the previous value as float.
-  fetched_float = *( (thread T*) &fetched_uint );
-
-  // Do addition and represent the addition result in uint for atomic transaction.
-  assigning_float = fetched_float + val;
-  assigning_uint = *((thread uint*) &assigning_float);
-
-  // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
-  while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
-    // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
-    // Try to assign 0 and get the previously assigned addition result.
-    uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
-    T fetched_float_again = *( (thread T*) &fetched_uint_again );
-    // Re-add again
-    fetched_float = *((thread T*) &(fetched_uint));
-    // Previously assigned addition result + addition result from other threads.
-    assigning_float = fetched_float_again + fetched_float;
-    assigning_uint = *( (thread uint*) &assigning_float);
-  }
-#endif
+  atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
+}
+
+
+inline void atomic_add_float(device half* data_ptr, const half val)
+{
+  atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast<float>(val), memory_order_relaxed);
 }
 
 template <typename T, typename integer_t>
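
The templated backward kernels elsewhere in this file keep calling `atomic_add_float` with either float or half data, and the two overloads above replace the removed template. A hedged sketch of how overload resolution now picks the right variant (the helper below is illustrative and not part of this diff):

// Illustrative only: inside a templated kernel, T is float or half and the
// matching atomic_add_float overload is chosen at compile time; the half
// overload widens the value to float before performing the atomic add.
template <typename T>
inline void scatter_grad_example(device T* grad_output_ptr, T val, uint index) {
  atomic_add_float(grad_output_ptr + index, val);
}
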
@@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t);
 REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
 REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
 
-)VISION_METAL";
-
-static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
-  static id<MTLLibrary> visionLibrary = nil;
-  if (visionLibrary) {
-    return visionLibrary;
-  }
-
-  NSError* error = nil;
-  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
-  [options setLanguageVersion:MTLLanguageVersion2_3];
-  visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
-                                       options:options
-                                         error:&error];
-  TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
-  return visionLibrary;
-}
-
-static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
-  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
-  id<MTLComputePipelineState> pso = psoCache[kernel];
-  if (pso) {
-    return pso;
-  }
-
-  NSError* error = nil;
-  id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
-  id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
-  TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
-  pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
-  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+)VISION_METAL");
 
-  psoCache[kernel] = pso;
-  return pso;
+static id<MTLComputePipelineState> visionPipelineState(
+    id<MTLDevice> device,
+    const std::string& kernel) {
+  return lib.getPipelineStateForFunc(kernel);
 }
 
 } // namespace mps
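
With the library object in place, host code no longer compiles the Metal source or caches pipeline states by hand; `MetalShaderLibrary` owns both steps. A hedged sketch of a call site, assuming it sits alongside the `lib` definition above (the helper name and encoder handling are illustrative, not part of this commit):

// Hypothetical helper, for illustration only: fetch the compute pipeline state
// for one of the functions defined in the VISION_METAL source and bind it to a
// compute command encoder. Compilation and per-function caching happen inside
// MetalShaderLibrary, replacing the removed compileVisionOpsLibrary/psoCache pair.
static void bindVisionKernel(id<MTLComputeCommandEncoder> encoder, const std::string& kernel) {
  id<MTLComputePipelineState> pso = lib.getPipelineStateForFunc(kernel);
  [encoder setComputePipelineState:pso];
}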

torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm (+0 -1)

@@ -123,7 +123,6 @@
 
   float spatial_scale_f = static_cast<float>(spatial_scale);
 
-  auto num_rois = rois.size(0);
   auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
 
   if (grad.numel() == 0) {
