Skip to content

Commit

Permalink
Update softmax.cu
Browse files Browse the repository at this point in the history
  • Loading branch information
linjames0 authored Sep 20, 2023
1 parent 3dcc52c commit f5dfcc4
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
__global__ void softmax(float *d_in, float *d_out, float *expArr, float *redArr, int N) {
int col = blockIdx.x * blockDim.x + threadIdx.x;

if(col < N) {
if (col < N) {
float local_exp = expf(d_in[col]);
expArr[col] = local_exp;
redArr[col] = expArr[col];
Expand All @@ -24,19 +24,18 @@ __global__ void softmax(float *d_in, float *d_out, float *expArr, float *redArr,
__syncthreads();

// parallel reduction to compute sum
for(int stride = 1 << padding; stride >= 1; stride /= 2) {
if(col < stride) {
for (int stride = 1 << padding; stride >= 1; stride /= 2) {
if (col < stride) {
redArr[col] += redArr[col + stride];
}
}
__syncthreads();
}

// calculate e^x / sum(e^x) = softmax for each element
if(col == 0) {
if (col < N) {
float sum = redArr[0];
for(int i = 0; i < N; ++i) {
d_out[i] = expArr[i] / sum;
}
d_out[col] = expArr[col] / sum;
}
}

Expand All @@ -56,7 +55,7 @@ int main() {
cudaMalloc((void**)&redArr, 2 * N * sizeof(float));

// data initialization
for(int i = 0; i < N; ++i) {
for (int i = 0; i < N; ++i) {
h_in[i] = (float)(rand() % 5 + 1);
}

Expand Down

0 comments on commit f5dfcc4

Please sign in to comment.