Skip to content

Commit

Permalink
Commit a working (but slower) double-buffering version of kernel 10 (warptiling)
Browse files: browse the repository at this point in the history
  • Loading branch information
Simon Boehm committed Mar 4, 2023
1 parent f6e8446 commit 0184e95
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions src/kernels/10_kernel_warptiling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ __global__ void __launch_bounds__(NUM_THREADS)
const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4

// allocate SMEM for two blocktiles each of A and B (current + next) for double buffering
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];
__shared__ float As[2 * BM * BK];
__shared__ float Bs[2 * BK * BN];

// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
Expand All @@ -136,19 +136,44 @@ __global__ void __launch_bounds__(NUM_THREADS)
float regM[WMITER * TM] = {0.0};
float regN[WNITER * TN] = {0.0};

int As_offset = 0;
int Bs_offset = 0;

// double-buffering: load first blocktile into SMEM
wt::loadFromGmem<BM, BN, BK, rowStrideA, rowStrideB>(
N, K, A, B, As + As_offset * BM * BK, Bs + Bs_offset * BK * BN, innerRowA,
innerColA, innerRowB, innerColB);

__syncthreads();

// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
for (uint bkIdx = 0; bkIdx < K - BK; bkIdx += BK) {
// double-buffering: load next blocktile into SMEM
wt::loadFromGmem<BM, BN, BK, rowStrideA, rowStrideB>(
N, K, A, B, As, Bs, innerRowA, innerColA, innerRowB, innerColB);
__syncthreads();
N, K, A + BK, B + BK * N, As + (1 - As_offset) * BM * BK,
Bs + (1 - Bs_offset) * BK * BN, innerRowA, innerColA, innerRowB,
innerColB);

// compute the current blocktile
wt::processFromSmem<BM, BN, BK, WM, WN, WMITER, WNITER, WSUBM, WSUBN, TM,
TN>(regM, regN, threadResults, As, Bs, warpRow, warpCol,
TN>(regM, regN, threadResults, As + As_offset * BM * BK,
Bs + Bs_offset * BK * BN, warpRow, warpCol,
threadRowInWarp, threadColInWarp);
A += BK; // move BK columns to right
B += BK * N; // move BK rows down

__syncthreads();

As_offset = 1 - As_offset;
Bs_offset = 1 - Bs_offset;
}

// compute the last blocktile
wt::processFromSmem<BM, BN, BK, WM, WN, WMITER, WNITER, WSUBM, WSUBN, TM, TN>(
regM, regN, threadResults, As + As_offset * BM * BK,
Bs + Bs_offset * BK * BN, warpRow, warpCol, threadRowInWarp,
threadColInWarp);

// write out the results
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
Expand Down

0 comments on commit 0184e95

Please sign in to comment.