Bug fix and performance optimization for rtc (apache#10018)
* Bug fix and performance optimization for rtc

1. "super().__init__()" bug is fixed in python 2.
2. Kernel is initialized in the stage of operator init.

* Update custom_softmax_rtc.py

Fix unnecessary formatting
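For context, a minimal sketch of the Python 2 issue behind item 1 (the class names here are illustrative, not from the example file): zero-argument super() only exists in Python 3, so the portable form passes the class and instance explicitly.

    class Base(object):
        def __init__(self):
            print("Base initialized")

    class Child(Base):
        def __init__(self):
            # super().__init__()            # Python 3 only; raises TypeError under Python 2
            super(Child, self).__init__()   # works on both Python 2 and Python 3

    Child()  # prints "Base initialized" under either interpreter

Item 2 corresponds to the diff below: the mx.rtc.CudaModule compilation and get_kernel lookups move from forward/backward into __init__, so each pass only launches an already-cached kernel handle.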
chinakook authored and Jin Huang committed Mar 30, 2018
1 parent 065ecaf commit cd8deff
Showing 1 changed file with 72 additions and 59 deletions.
131 changes: 72 additions & 59 deletions example/numpy-ops/custom_softmax_rtc.py
@@ -23,78 +23,91 @@

class Softmax(mx.operator.CustomOp):
    def __init__(self):
-        self.fwd_kernel_mod = None
-        self.bwd_kernel_mod = None
-        super().__init__()
+        super(Softmax,self).__init__()
+        # Each thread processes a row (a sample in the batch).
+        fwd_src = r"""
+        template<class DType>
+        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+            const int offset = row_size * threadIdx.x;
+            DType max = x[offset];
+            for(int i = 1; i < row_size; ++i) {
+                if(max < x[offset + i]) {
+                    max = x[offset + i];
+                }
+            }
+            DType sum = 0;
+            for(int i = 0; i < row_size; ++i) {
+                sum += exp(x[offset + i] - max);
+            }
+            switch(req) {
+                case 1:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] = exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+                case 2:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] += exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+            }
+        }
+        """
+
+        # Each block processes a row and each thread in a block calculates an element of `dx`.
+        bwd_src = r"""
+        template<class DType>
+        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+            const int z = static_cast<int>(l[blockIdx.x]);
+            const int i = threadIdx.x + blockDim.x * blockIdx.x;
+            if(req == 1) {
+                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
+            } else {
+                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+            }
+        }
+        """
+        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
+        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
+
+        fwd_kernel_float_signature = "const float*, const float*, const int, const int"
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)
+
+        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)
+
+        fwd_kernel_double_signature = "const double*, const double*, const int, const int"
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)
+
+        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)

    def forward(self, is_train, req, in_data, out_data, aux):
        if req[0] == "null":
            return
        x = in_data[0] # input
        y = out_data[0] # output
-        if self.fwd_kernel_mod is None:
-            # Each thread processes a row (a sample in the batch).
-            src = r"""
-            template<class DType>
-            __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
-                const int offset = row_size * threadIdx.x;
-                DType max = x[offset];
-                for(int i = 1; i < row_size; ++i) {
-                    if(max < x[offset + i]) {
-                        max = x[offset + i];
-                    }
-                }
-                DType sum = 0;
-                for(int i = 0; i < row_size; ++i) {
-                    sum += exp(x[offset + i] - max);
-                }
-                switch(req) {
-                    case 1:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] = exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                    case 2:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] += exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                }
-            }
-            """
-            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
-        dtype = "double" if y.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
-        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))

+        if y.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        if req[0] == "null":
            return
        l = in_data[1] # label
        y = out_data[0] # output from the forward pass
        dx = in_grad[0] # the storage for the gradient
-        if self.bwd_kernel_mod is None:
-            # Each block processes a row and each thread in a block calculate an element of `dx`.
-            src = r"""
-            template<class DType>
-            __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
-                const int z = static_cast<int>(l[blockIdx.x]);
-                const int i = threadIdx.x + blockDim.x * blockIdx.x;
-                if(req == 1) {
-                    dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
-                } else {
-                    dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
-                }
-            }
-            """
-            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
-        dtype = "double" if dx.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
-        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))

+        if dx.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))

    def _reqCode(self, req):
        if(req == "write"):
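The remainder of the file is collapsed in the diff view above. As a rough, hypothetical sketch of how a CustomOp like this is typically exposed to user code (the property class, registration name, and shapes below are illustrative assumptions, not part of the shown diff):

    import mxnet as mx
    import numpy as np

    @mx.operator.register("softmax_rtc")   # hypothetical registration name
    class SoftmaxProp(mx.operator.CustomOpProp):
        def __init__(self):
            super(SoftmaxProp, self).__init__(need_top_grad=False)

        def list_arguments(self):
            return ['data', 'label']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            data_shape = in_shape[0]
            label_shape = (in_shape[0][0],)   # one label per row of the batch
            output_shape = in_shape[0]
            return [data_shape, label_shape], [output_shape], []

        def create_operator(self, ctx, shapes, dtypes):
            return Softmax()

    # The operator could then be used from symbolic code on a GPU context, e.g.:
    # net = mx.symbol.Custom(data=fc_output, label=label, op_type='softmax_rtc')

Since the kernels are compiled with mx.rtc, the operator only runs on a GPU context; the launch calls in the diff hard-code mx.gpu(0).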
