class Softmax(mx.operator.CustomOp):
    def __init__(self):
-        self.fwd_kernel_mod = None
-        self.bwd_kernel_mod = None
-        super().__init__()
+        super(Softmax, self).__init__()
+        # Each thread processes a row (a sample in the batch).
+        fwd_src = r"""
+        template<class DType>
+        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+            const int offset = row_size * threadIdx.x;
+            DType max = x[offset];
+            for(int i = 1; i < row_size; ++i) {
+                if(max < x[offset + i]) {
+                    max = x[offset + i];
+                }
+            }
+            DType sum = 0;
+            for(int i = 0; i < row_size; ++i) {
+                sum += exp(x[offset + i] - max);
+            }
+            switch(req) {
+                case 1:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] = exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+                case 2:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] += exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+            }
+        }
+        """
+
+        # Each block processes a row and each thread in a block calculates an element of `dx`.
+        bwd_src = r"""
+        template<class DType>
+        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+            const int z = static_cast<int>(l[blockIdx.x]);
+            const int i = threadIdx.x + blockDim.x * blockIdx.x;
+            if(req == 1) {
+                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
+            } else {
+                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+            }
+        }
+        """
+        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
+        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
+
+        fwd_kernel_float_signature = "const float*, const float*, const int, const int"
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)
+
+        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)
+
+        fwd_kernel_double_signature = "const double*, const double*, const int, const int"
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)
+
+        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)

    def forward(self, is_train, req, in_data, out_data, aux):
        if req[0] == "null":
            return
        x = in_data[0]  # input
        y = out_data[0]  # output
-        if self.fwd_kernel_mod is None:
-            # Each thread processes a row (a sample in the batch).
-            src = r"""
-            template<class DType>
-            __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
-                const int offset = row_size * threadIdx.x;
-                DType max = x[offset];
-                for(int i = 1; i < row_size; ++i) {
-                    if(max < x[offset + i]) {
-                        max = x[offset + i];
-                    }
-                }
-                DType sum = 0;
-                for(int i = 0; i < row_size; ++i) {
-                    sum += exp(x[offset + i] - max);
-                }
-                switch(req) {
-                    case 1:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] = exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                    case 2:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] += exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                }
-            }
-            """
-            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
-        dtype = "double" if y.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
-        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+
+        if y.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        if req[0] == "null":
            return
        l = in_data[1]   # label
        y = out_data[0]  # output from the forward pass
        dx = in_grad[0]  # the storage for the gradient
-        if self.bwd_kernel_mod is None:
-            # Each block processes a row and each thread in a block calculates an element of `dx`.
-            src = r"""
-            template<class DType>
-            __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
-                const int z = static_cast<int>(l[blockIdx.x]);
-                const int i = threadIdx.x + blockDim.x * blockIdx.x;
-                if(req == 1) {
-                    dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
-                } else {
-                    dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
-                }
-            }
-            """
-            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
-        dtype = "double" if dx.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
-        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+
+        if dx.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))

    def _reqCode(self, req):
        if(req == "write"):
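The commit above moves kernel compilation out of forward/backward and into __init__, so the float and double variants are built once and the hot path only selects the right kernel by dtype. For context, the op still has to be registered through a CustomOpProp subclass before it can be used from a graph. The sketch below is illustrative and not part of this commit: the property-class name, the "softmax" op-type string, and the shape logic are assumptions, and it presumes an MXNet build with CUDA and NVRTC support, since both kernels launch on mx.gpu(0) via mx.rtc.

import mxnet as mx

@mx.operator.register("softmax")  # op-type string is an assumption, not taken from the commit
class SoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        # The backward pass reads the label directly, so no gradient from the layer above is needed.
        super(SoftmaxProp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)  # one label per row, matching the per-row kernels
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return Softmax()

# Hypothetical usage: reference the op by its registered op_type and bind on a GPU
# context so forward/backward can launch the RTC kernels.
data = mx.sym.Variable('data')
net = mx.sym.Custom(data=data, name='softmax', op_type='softmax')

With this in place, MXNet calls create_operator to instantiate Softmax and passes the per-output req strings ("null", "write", "add") that the truncated _reqCode helper translates into the integer cases the kernels switch on.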
|