
Commit cd8deff

chinakook (Jin Huang) authored and Jin Huang committed
Bug fix and performance optimization for RTC (apache#10018)
* Bug fix and performance optimization for RTC: 1. The `super().__init__()` call is fixed for Python 2. 2. Kernels are now compiled once when the operator is initialized. * Update custom_softmax_rtc.py: fix unnecessary formatting
1 parent 065ecaf commit cd8deff
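
The first fix concerns Python 2 compatibility: zero-argument `super()` exists only in Python 3, so the old `super().__init__()` raises a TypeError under Python 2. A minimal standalone sketch of the portable form (the `Base`/`Child` classes here are illustrative, not from the commit):

class Base(object):
    def __init__(self):
        self.ready = True

class Child(Base):
    def __init__(self):
        # super().__init__() is Python 3 only; the explicit
        # two-argument form works under both Python 2 and 3.
        super(Child, self).__init__()

print(Child().ready)   # True on either interpreter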

File tree

1 file changed (+72 −59 lines)


example/numpy-ops/custom_softmax_rtc.py

@@ -23,78 +23,91 @@
 
 class Softmax(mx.operator.CustomOp):
     def __init__(self):
-        self.fwd_kernel_mod = None
-        self.bwd_kernel_mod = None
-        super().__init__()
+        super(Softmax, self).__init__()
+        # Each thread processes a row (a sample in the batch).
+        fwd_src = r"""
+        template<class DType>
+        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+            const int offset = row_size * threadIdx.x;
+            DType max = x[offset];
+            for(int i = 1; i < row_size; ++i) {
+                if(max < x[offset + i]) {
+                    max = x[offset + i];
+                }
+            }
+            DType sum = 0;
+            for(int i = 0; i < row_size; ++i) {
+                sum += exp(x[offset + i] - max);
+            }
+            switch(req) {
+                case 1:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] = exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+                case 2:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] += exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+            }
+        }
+        """
+
+        # Each block processes a row and each thread in a block calculates an element of `dx`.
+        bwd_src = r"""
+        template<class DType>
+        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+            const int z = static_cast<int>(l[blockIdx.x]);
+            const int i = threadIdx.x + blockDim.x * blockIdx.x;
+            if(req == 1) {
+                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
+            } else {
+                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+            }
+        }
+        """
+        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
+        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
+
+        fwd_kernel_float_signature = "const float*, const float*, const int, const int"
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)
+
+        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)
+
+        fwd_kernel_double_signature = "const double*, const double*, const int, const int"
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)
+
+        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)
 
     def forward(self, is_train, req, in_data, out_data, aux):
         if req[0] == "null":
             return
         x = in_data[0]  # input
         y = out_data[0]  # output
-        if self.fwd_kernel_mod is None:
-            # Each thread processes a row (a sample in the batch).
-            src = r"""
-            template<class DType>
-            __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
-                const int offset = row_size * threadIdx.x;
-                DType max = x[offset];
-                for(int i = 1; i < row_size; ++i) {
-                    if(max < x[offset + i]) {
-                        max = x[offset + i];
-                    }
-                }
-                DType sum = 0;
-                for(int i = 0; i < row_size; ++i) {
-                    sum += exp(x[offset + i] - max);
-                }
-                switch(req) {
-                    case 1:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] = exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                    case 2:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] += exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                }
-            }
-            """
-            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
-        dtype = "double" if y.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
-        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+
+        if y.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
 
     def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
         if req[0] == "null":
             return
         l = in_data[1]  # label
         y = out_data[0]  # output from the forward pass
         dx = in_grad[0]  # the storage for the gradient
-        if self.bwd_kernel_mod is None:
-            # Each block processes a row and each thread in a block calculate an element of `dx`.
-            src = r"""
-            template<class DType>
-            __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
-                const int z = static_cast<int>(l[blockIdx.x]);
-                const int i = threadIdx.x + blockDim.x * blockIdx.x;
-                if(req == 1) {
-                    dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
-                } else {
-                    dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
-                }
-            }
-            """
-            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
-        dtype = "double" if dx.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
-        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+
+        if dx.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
 
     def _reqCode(self, req):
         if(req == "write"):
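
The hunk is cut off inside `_reqCode`, whose body is not shown above. As a hedged sketch only: judging from the kernels' `switch(req)` (case 1 overwrites the output, case 2 accumulates into it), the helper presumably maps MXNet's request strings to those integers, roughly:

    def _reqCode(self, req):
        # Hypothetical completion; the committed body may differ.
        # Map MXNet's request string to the integer consumed by the kernels.
        if req == "write":
            return 1
        elif req == "add":
            return 2
        elif req == "null":
            return 0
        else:
            raise ValueError("Invalid value of `req`: {}".format(req))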

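For context on how such an operator is wired into a network: MXNet custom ops are exposed to the symbolic API through a `CustomOpProp` subclass. The sketch below follows that general pattern; the class name `SoftmaxProp`, the op_type string "softmax", and the shapes are illustrative assumptions, not part of this diff.

import mxnet as mx

@mx.operator.register("softmax")  # op_type name is an assumption
class SoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        # The loss consumes the label directly, so no gradient from the top is needed.
        super(SoftmaxProp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)   # one label per row/sample
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        # With this commit, kernel compilation happens here (once),
        # not lazily on the first forward pass.
        return Softmax()

# Illustrative use in a symbol graph:
# out = mx.sym.Custom(data=fc, label=label, name='softmax', op_type='softmax')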