Skip to content

Commit

Permalink
K6,K9: Make sync non-conditional again
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Boehm committed Mar 4, 2023
1 parent fd03f69 commit 1bc1353
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
7 changes: 2 additions & 5 deletions src/kernels/6_kernel_vectorize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,8 @@ __global__ void sgemmVectorize(int M, int N, int K, float alpha, float *A,
float regM[TM] = {0.0};
float regN[TN] = {0.0};

- // outer-most loop over block tiles
- #pragma unroll
+ // outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
- if (bkIdx != 0) {
- __syncthreads();
- }
// populate the SMEM caches
// transpose A while loading it
float4 tmp =
Expand Down Expand Up @@ -79,6 +75,7 @@ __global__ void sgemmVectorize(int M, int N, int K, float alpha, float *A,
}
}
}
+ __syncthreads();
}

// write out the results
Expand Down
4 changes: 1 addition & 3 deletions src/kernels/9_kernel_autotuned.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ __global__ void __launch_bounds__(K9_NUM_THREADS)
// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// populate the SMEM caches
- if (bkIdx != 0) {
- __syncthreads();
- }
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
float4 tmp = reinterpret_cast<float4 *>(
&A[(innerRowA + offset) * K + innerColA * 4])[0];
Expand Down Expand Up @@ -96,6 +93,7 @@ __global__ void __launch_bounds__(K9_NUM_THREADS)
}
}
}
+ __syncthreads();
// advance blocktile
A += BK; // move BK columns to right
B += BK * N; // move BK rows down
Expand Down

0 comments on commit 1bc1353

Please sign in to comment.