
Bug Fix and performance optimized for rtc #10018

Merged
merged 2 commits into from Mar 9, 2018
131 changes: 72 additions & 59 deletions example/numpy-ops/custom_softmax_rtc.py
@@ -23,78 +23,91 @@

class Softmax(mx.operator.CustomOp):
    def __init__(self):
-        self.fwd_kernel_mod = None
-        self.bwd_kernel_mod = None
-        super().__init__()
+        super(Softmax,self).__init__()
+        # Each thread processes a row (a sample in the batch).
+        fwd_src = r"""
+        template<class DType>
+        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+            const int offset = row_size * threadIdx.x;
+            DType max = x[offset];
+            for(int i = 1; i < row_size; ++i) {
+                if(max < x[offset + i]) {
+                    max = x[offset + i];
+                }
+            }
+            DType sum = 0;
+            for(int i = 0; i < row_size; ++i) {
+                sum += exp(x[offset + i] - max);
+            }
+            switch(req) {
+                case 1:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] = exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+                case 2:
+                    for(int i = 0; i < row_size; ++i) {
+                        y[offset + i] += exp(x[offset + i] - max) / sum;
+                    }
+                    break;
+            }
+        }
+        """
+
+        # Each block processes a row and each thread in a block calculate an element of `dx`.
+        bwd_src = r"""
+        template<class DType>
+        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+            const int z = static_cast<int>(l[blockIdx.x]);
+            const int i = threadIdx.x + blockDim.x * blockIdx.x;
+            if(req == 1) {
+                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
+            } else {
+                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+            }
+        }
+        """
+        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
+        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])
+
+        fwd_kernel_float_signature = "const float*, const float*, const int, const int"
+        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)
+
+        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
+        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)
+
+        fwd_kernel_double_signature = "const double*, const double*, const int, const int"
+        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)
+
+        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
+        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)

    def forward(self, is_train, req, in_data, out_data, aux):
        if req[0] == "null":
            return
        x = in_data[0]  # input
        y = out_data[0]  # output
-        if self.fwd_kernel_mod is None:
-            # Each thread processes a row (a sample in the batch).
-            src = r"""
-            template<class DType>
-            __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
-                const int offset = row_size * threadIdx.x;
-                DType max = x[offset];
-                for(int i = 1; i < row_size; ++i) {
-                    if(max < x[offset + i]) {
-                        max = x[offset + i];
-                    }
-                }
-                DType sum = 0;
-                for(int i = 0; i < row_size; ++i) {
-                    sum += exp(x[offset + i] - max);
-                }
-                switch(req) {
-                    case 1:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] = exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                    case 2:
-                        for(int i = 0; i < row_size; ++i) {
-                            y[offset + i] += exp(x[offset + i] - max) / sum;
-                        }
-                        break;
-                }
-            }
-            """
-            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
-        dtype = "double" if y.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
-        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        if y.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        if req[0] == "null":
            return
        l = in_data[1]  # label
        y = out_data[0]  # output from the forward pass
        dx = in_grad[0]  # the storage for the gradient
-        if self.bwd_kernel_mod is None:
-            # Each block processes a row and each thread in a block calculate an element of `dx`.
-            src = r"""
-            template<class DType>
-            __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
-                const int z = static_cast<int>(l[blockIdx.x]);
-                const int i = threadIdx.x + blockDim.x * blockIdx.x;
-                if(req == 1) {
-                    dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
-                } else {
-                    dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
-                }
-            }
-            """
-            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
-        dtype = "double" if dx.dtype == np.float64 else "float"
-        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
-        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
-        # args, ctx, grid_shape, block_shape, shared_mem = 0
-        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        if dx.dtype == np.float64:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+        else:
+            # args, ctx, grid_shape, block_shape, shared_mem = 0
+            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))

    def _reqCode(self, req):
        if(req == "write"):
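
For context, here is a minimal, hypothetical sketch of how this example operator is typically exercised once the change is applied. It assumes the file registers a CustomOpProp for this operator under the op_type name 'softmax' via mx.operator.register (that registration is outside this hunk), and it requires a GPU context because the kernels are launched on mx.gpu(0):

    import mxnet as mx

    # Build a tiny symbolic network that ends in the custom RTC softmax.
    # The op_type string is an assumption; it must match the name used in
    # mx.operator.register(...) elsewhere in custom_softmax_rtc.py.
    data = mx.symbol.Variable('data')
    fc = mx.symbol.FullyConnected(data=data, name='fc', num_hidden=10)
    net = mx.symbol.Custom(data=fc, name='softmax', op_type='softmax')

    # The kernels are compiled with mx.rtc and launched on mx.gpu(0),
    # so the module must be bound to a GPU context.
    mod = mx.mod.Module(net, context=mx.gpu(0))

Because the CUDA modules and kernels are now created once in __init__, repeated forward/backward calls reuse self.fwd_*_kernel and self.bwd_*_kernel instead of rebuilding an mx.rtc.CudaModule on every call, which is the performance side of this change.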