diff --git a/src/kernels/10_kernel_warptiling.cuh b/src/kernels/10_kernel_warptiling.cuh index 1fcb141..e3ec0cc 100644 --- a/src/kernels/10_kernel_warptiling.cuh +++ b/src/kernels/10_kernel_warptiling.cuh @@ -112,8 +112,8 @@ __global__ void __launch_bounds__(NUM_THREADS) const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4 // allocate space for the current blocktile in SMEM - __shared__ float As[BM * BK]; - __shared__ float Bs[BK * BN]; + __shared__ float As[2 * BM * BK]; + __shared__ float Bs[2 * BK * BN]; // Move blocktile to beginning of A's row and B's column A += cRow * BM * K; @@ -136,19 +136,44 @@ __global__ void __launch_bounds__(NUM_THREADS) float regM[WMITER * TM] = {0.0}; float regN[WNITER * TN] = {0.0}; + int As_offset = 0; + int Bs_offset = 0; + + // double-buffering: load first blocktile into SMEM + wt::loadFromGmem( + N, K, A, B, As + As_offset * BM * BK, Bs + Bs_offset * BK * BN, innerRowA, + innerColA, innerRowB, innerColB); + + __syncthreads(); + // outer-most loop over block tiles - for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + for (uint bkIdx = 0; bkIdx < K - BK; bkIdx += BK) { + // double-buffering: load next blocktile into SMEM wt::loadFromGmem( - N, K, A, B, As, Bs, innerRowA, innerColA, innerRowB, innerColB); - __syncthreads(); + N, K, A + BK, B + BK * N, As + (1 - As_offset) * BM * BK, + Bs + (1 - Bs_offset) * BK * BN, innerRowA, innerColA, innerRowB, + innerColB); + + // compute the current blocktile wt::processFromSmem(regM, regN, threadResults, As, Bs, warpRow, warpCol, + TN>(regM, regN, threadResults, As + As_offset * BM * BK, + Bs + Bs_offset * BK * BN, warpRow, warpCol, threadRowInWarp, threadColInWarp); A += BK; // move BK columns to right B += BK * N; // move BK rows down + __syncthreads(); + + As_offset = 1 - As_offset; + Bs_offset = 1 - Bs_offset; } + // compute the last blocktile + wt::processFromSmem( + regM, regN, threadResults, As + As_offset * BM * BK, + Bs + Bs_offset * BK * BN, warpRow, warpCol, threadRowInWarp, + threadColInWarp); + // write out the results for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {