@@ -5,7 +5,7 @@ namespace ops {
 
 namespace mps {
 
-static const char* METAL_VISION = R"VISION_METAL(
+static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL(
 
 #include <metal_atomic>
 #include <metal_stdlib>
@@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) {
 return (n + m - 1) / m;
 }
 
-template <typename T>
-inline void atomic_add_float( device T* data_ptr, const T val)
+inline void atomic_add_float(device float* data_ptr, const float val)
 {
-#if __METAL_VERSION__ >= 300
-  // atomic_float is supported in Metal 3 (macOS Ventura) onward.
-  device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
-#else
-  // Custom atomic addition implementation
-  // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
-  // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
-  // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
-
-  // Create an atomic uint pointer for atomic transaction.
-  device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
-  // Create necessary storage.
-  uint fetched_uint, assigning_uint;
-  T fetched_float, assigning_float;
-
-  // Replace the value in atom_var with 0 and return the previous value in atom_var.
-  fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
-  // Read out the previous value as float.
-  fetched_float = *( (thread T*) &fetched_uint );
-
-  // Do addition and represent the addition result in uint for atomic transaction.
-  assigning_float = fetched_float + val;
-  assigning_uint = *((thread uint*) &assigning_float);
-
-  // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
-  while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
-    // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
-    // Try to assign 0 and get the previously assigned addition result.
-    uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
-    T fetched_float_again = *( (thread T*) &fetched_uint_again );
-    // Re-add again
-    fetched_float = *((thread T*) &(fetched_uint));
-    // Previously assigned addition result + addition result from other threads.
-    assigning_float = fetched_float_again + fetched_float;
-    assigning_uint = *( (thread uint*) &assigning_float);
-  }
-#endif
+  atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
+}
+
+
+inline void atomic_add_float(device half* data_ptr, const half val)
+{
+  atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast<float>(val), memory_order_relaxed);
 }
 
 template <typename T, typename integer_t>
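
For context, a toy Metal sketch (not part of this diff) of the call pattern the new overloads serve; the kernel and buffer names are placeholders. With the template gone, overload resolution picks the float or half version from the pointer type, and both now rely unconditionally on native atomic_float support (the old __METAL_VERSION__ >= 300 fallback path is removed).

// Illustrative only: several threads may hit the same output slot, so the
// accumulation must be atomic; this call resolves to the float overload above.
kernel void example_scatter_add(
    constant float* values  [[buffer(0)]],
    constant uint*  indices [[buffer(1)]],
    device   float* output  [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  atomic_add_float(output + indices[tid], values[tid]);
}
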
@@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t);
 REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
 REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
 
-)VISION_METAL";
-
-static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
-  static id<MTLLibrary> visionLibrary = nil;
-  if (visionLibrary) {
-    return visionLibrary;
-  }
-
-  NSError* error = nil;
-  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
-  [options setLanguageVersion:MTLLanguageVersion2_3];
-  visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
-                                       options:options
-                                         error:&error];
-  TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
-  return visionLibrary;
-}
-
-static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
-  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
-  id<MTLComputePipelineState> pso = psoCache[kernel];
-  if (pso) {
-    return pso;
-  }
-
-  NSError* error = nil;
-  id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
-  id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
-  TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
-  pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
-  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+)VISION_METAL");
 
-  psoCache[kernel] = pso;
-  return pso;
+static id<MTLComputePipelineState> visionPipelineState(
+    id<MTLDevice> device,
+    const std::string& kernel) {
+  return lib.getPipelineStateForFunc(kernel);
 }
 
 } // namespace mps
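
For context, a minimal host-side sketch (not part of this change) of how an op might consume the new helper. Only visionPipelineState and lib.getPipelineStateForFunc come from the diff above; the encoder plumbing, buffer layout, and the kernel name "some_vision_kernel" are illustrative placeholders, and the real call sites in the individual *_kernel.mm files are untouched by this diff.

// Illustrative only: bind a pipeline state produced by the MetalShaderLibrary
// and dispatch it. Compilation and PSO caching now happen inside
// at::native::mps::MetalShaderLibrary instead of the removed local statics.
static void encodeExampleKernel(
    id<MTLDevice> device,
    id<MTLComputeCommandEncoder> encoder,
    id<MTLBuffer> input,
    id<MTLBuffer> output,
    uint32_t numElements) {
  // The device parameter is kept only because visionPipelineState still
  // accepts one; the library-backed implementation no longer uses it.
  id<MTLComputePipelineState> pso = visionPipelineState(device, "some_vision_kernel");

  [encoder setComputePipelineState:pso];
  [encoder setBuffer:input offset:0 atIndex:0];
  [encoder setBuffer:output offset:0 atIndex:1];

  const NSUInteger maxThreads = [pso maxTotalThreadsPerThreadgroup];
  const NSUInteger tgWidth = MIN(maxThreads, (NSUInteger)numElements);
  [encoder dispatchThreads:MTLSizeMake(numElements, 1, 1)
      threadsPerThreadgroup:MTLSizeMake(tgWidth, 1, 1)];
}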