Add skeleton for double buffering
siboehm committed Feb 26, 2023
1 parent c09b350 commit c990b12
Showing 3 changed files with 209 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/kernels.cuh
@@ -1,6 +1,7 @@
#pragma once

#include "kernels/10_kernel_warptiling.cuh"
#include "kernels/11_kernel_double_buffering.cuh"
#include "kernels/1_naive.cuh"
#include "kernels/2_kernel_global_mem_coalesce.cuh"
#include "kernels/3_kernel_shared_mem_blocking.cuh"
154 changes: 154 additions & 0 deletions src/kernels/11_kernel_double_buffering.cuh
@@ -0,0 +1,154 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
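// rounds up to the next multiple, e.g. CEIL_DIV(4096, 128) == 32; runner.cu
// uses it to size the launch grid so partial tiles at the matrix edge still
// get a block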

/*
* @tparam BM The threadblock size for M dimension SMEM caching.
* @tparam BN The threadblock size for N dimension SMEM caching.
* @tparam BK The threadblock size for K dimension SMEM caching.
 * @tparam WM M dim of contiguous tile computed by each warp
 * @tparam WN N dim of contiguous tile computed by each warp
 * @tparam WNITER The number of subwarp tiling steps in N dimension.
 *         (WMITER, its M-dimension counterpart, is derived inside the kernel.)
* @tparam TM The per-thread tile size for M dimension.
* @tparam TN The per-thread tile size for N dimension.
*/
template <const int BM, const int BN, const int BK, const int WM, const int WN,
const int WNITER, const int TM, const int TN, const int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
sgemmDoubleBuffering(int M, int N, int K, float alpha, float *A, float *B,
float beta, float *C) {
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;

// Placement of the warp in the threadblock tile
const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in
const uint warpCol = warpIdx % (BN / WN);
const uint warpRow = warpIdx / (BN / WN);

// size of the warp subtile
  constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
  constexpr uint WSUBM = WM / WMITER; // e.g. 64/2=32 when WM=64, WMITER=2
  constexpr uint WSUBN = WN / WNITER; // e.g. 32/2=16 when WN=32, WNITER=2

// Placement of the thread in the warp subtile
const uint threadIdxInWarp = threadIdx.x % WARPSIZE; // [0, 31]
const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4)
const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4

// allocate space for the current blocktile in SMEM
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];

// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
// Move C_ptr to warp's output tile
C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN;

// calculating the indices that this thread will load into SMEM
// we'll load 128bit / 32bit = 4 elements per thread at each step
const uint innerRowA = threadIdx.x / (BK / 4);
const uint innerColA = threadIdx.x % (BK / 4);
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
const uint innerRowB = threadIdx.x / (BN / 4);
const uint innerColB = threadIdx.x % (BN / 4);
constexpr uint rowStrideB = NUM_THREADS / (BN / 4);
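  // e.g. with the K11 config in runner.cu (NUM_THREADS=128, BM=64, BK=16,
  // BN=128): rowStrideA=32 (2 passes to fill As) and rowStrideB=4 (4 passes
  // to fill Bs)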

  // allocate thread-local cache for results in the register file
float threadResults[WMITER * TM * WNITER * TN] = {0.0};
// we cache into registers on the warptile level
float regM[WMITER * TM] = {0.0};
float regN[WNITER * TN] = {0.0};
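  // with the K11 config in runner.cu these hold 64, 16 and 4 floats per
  // thread respectively (WMITER=4, WNITER=1, TM=TN=4)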

// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
float4 tmp = reinterpret_cast<float4 *>(
&A[(innerRowA + offset) * K + innerColA * 4])[0];
// transpose A while storing it
As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w;
}

for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
reinterpret_cast<float4 *>(
&Bs[(innerRowB + offset) * BN + innerColB * 4])[0] =
reinterpret_cast<float4 *>(
&B[(innerRowB + offset) * N + innerColB * 4])[0];
}
__syncthreads();

for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// populate registers for whole warptile
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint i = 0; i < TM; ++i) {
regM[wSubRowIdx * TM + i] =
As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM +
threadRowInWarp * TM + i];
}
}
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
for (uint i = 0; i < TN; ++i) {
regN[wSubColIdx * TN + i] =
Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN +
threadColInWarp * TN + i];
}
}

// execute warptile matmul
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
// calculate per-thread results
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
(wSubColIdx * TN) + resIdxN] +=
regM[wSubRowIdx * TM + resIdxM] *
regN[wSubColIdx * TN + resIdxN];
}
}
}
}
}
A += BK; // move BK columns to right
B += BK * N; // move BK rows down
__syncthreads();
}

// write out the results
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
// move C pointer to current warp subtile
float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN;
for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
// load C vector into registers
float4 tmp = reinterpret_cast<float4 *>(
&C_interim[(threadRowInWarp * TM + resIdxM) * N +
threadColInWarp * TN + resIdxN])[0];
// perform GEMM update in reg
const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
wSubColIdx * TN + resIdxN;
tmp.x = alpha * threadResults[i + 0] + beta * tmp.x;
tmp.y = alpha * threadResults[i + 1] + beta * tmp.y;
tmp.z = alpha * threadResults[i + 2] + beta * tmp.z;
tmp.w = alpha * threadResults[i + 3] + beta * tmp.w;
// write back
reinterpret_cast<float4 *>(
&C_interim[(threadRowInWarp * TM + resIdxM) * N +
threadColInWarp * TN + resIdxN])[0] = tmp;
}
}
}
}
}
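
Note: the kernel body added above still mirrors the warptiling kernel; the double-buffered SMEM pipeline itself is not implemented in this commit. As a rough illustration of the pattern the skeleton is meant to grow into, the minimal sketch below keeps two SMEM buffers and prefetches the next K-tile into one while computing on the other. The sketch (its name, the TILE parameter, and the simplified non-warptiled indexing) is an assumption for illustration only, not part of this commit.

// Sketch only: a minimal, non-warptiled SGEMM with SMEM double buffering.
// Assumes M, N, K are multiples of TILE, blockDim == (TILE, TILE) and
// gridDim == (N / TILE, M / TILE); same includes as the kernel above.
template <const int TILE>
__global__ void sgemmDoubleBufferSketch(int M, int N, int K, float alpha,
                                        const float *A, const float *B,
                                        float beta, float *C) {
  __shared__ float As[2][TILE][TILE];
  __shared__ float Bs[2][TILE][TILE];

  const uint row = blockIdx.y * TILE + threadIdx.y;
  const uint col = blockIdx.x * TILE + threadIdx.x;

  float acc = 0.0f;
  int cur = 0;

  // preload the first K-tile into buffer 0
  As[cur][threadIdx.y][threadIdx.x] = A[row * K + threadIdx.x];
  Bs[cur][threadIdx.y][threadIdx.x] = B[threadIdx.y * N + col];
  __syncthreads();

  for (int bkIdx = 0; bkIdx < K; bkIdx += TILE) {
    const int nxt = 1 - cur;
    // prefetch the next K-tile into the other buffer; nothing reads it until
    // after the barrier below, so its latency can hide behind the FMAs
    if (bkIdx + TILE < K) {
      As[nxt][threadIdx.y][threadIdx.x] =
          A[row * K + bkIdx + TILE + threadIdx.x];
      Bs[nxt][threadIdx.y][threadIdx.x] =
          B[(bkIdx + TILE + threadIdx.y) * N + col];
    }
    // compute on the current buffer
    for (int k = 0; k < TILE; ++k) {
      acc += As[cur][threadIdx.y][k] * Bs[cur][k][threadIdx.x];
    }
    // one barrier per iteration: all reads of `cur` and writes of `nxt` have
    // finished before the buffers swap
    __syncthreads();
    cur = nxt;
  }

  C[row * N + col] = alpha * acc + beta * C[row * N + col];
}

Compared with the single-buffered loop above, this spends twice the shared memory to drop one of the two barriers per K-tile and to let the global loads for tile k+1 proceed while tile k is being consumed.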
54 changes: 54 additions & 0 deletions src/runner.cu
@@ -370,6 +370,57 @@ void runSgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B,
K10_TN, K10_NUM_THREADS>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}

void runSgemmDoubleBuffering(int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C) {
const uint K11_NUM_THREADS = 128;
const uint K11_BN = 128;
const uint K11_BM = 64;
const uint K11_BK = 16;
const uint K11_WN = 64;
const uint K11_WM = 32;
const uint K11_WNITER = 1;
const uint K11_TN = 4;
const uint K11_TM = 4;
dim3 blockDim(K11_NUM_THREADS);

constexpr uint NUM_WARPS = K11_NUM_THREADS / 32;
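  // = 4 with the settings above; must equal (BN/WN) * (BM/WM), as asserted below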

// warptile in threadblocktile
static_assert((K11_BN % K11_WN == 0) and (K11_BM % K11_WM == 0));
static_assert((K11_BN / K11_WN) * (K11_BM / K11_WM) == NUM_WARPS);

// threads in warpsubtile
static_assert((K11_WM * K11_WN) % (WARPSIZE * K11_TM * K11_TN * K11_WNITER) ==
0);
  constexpr uint K11_WMITER =
      (K11_WM * K11_WN) / (WARPSIZE * K11_TM * K11_TN * K11_WNITER);
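  // with the settings above: (32 * 64) / (32 * 4 * 4 * 1) == 4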
// warpsubtile in warptile
static_assert((K11_WM % K11_WMITER == 0) and (K11_WN % K11_WNITER == 0));

  static_assert((K11_NUM_THREADS * 4) % K11_BK == 0,
                "NUM_THREADS*4 must be multiple of K11_BK to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of As during each iteration)");
  static_assert((K11_NUM_THREADS * 4) % K11_BN == 0,
                "NUM_THREADS*4 must be multiple of K11_BN to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of Bs during each iteration)");
static_assert(K11_BN % (16 * K11_TN) == 0,
"BN must be a multiple of 16*TN to avoid quantization effects");
static_assert(K11_BM % (16 * K11_TM) == 0,
"BM must be a multiple of 16*TM to avoid quantization effects");
  static_assert((K11_BM * K11_BK) % (4 * K11_NUM_THREADS) == 0,
                "BM*BK must be a multiple of 4*NUM_THREADS to vectorize loads");
  static_assert((K11_BN * K11_BK) % (4 * K11_NUM_THREADS) == 0,
                "BN*BK must be a multiple of 4*NUM_THREADS to vectorize loads");

dim3 gridDim(CEIL_DIV(N, K11_BN), CEIL_DIV(M, K11_BM));
sgemmDoubleBuffering<K11_BM, K11_BN, K11_BK, K11_WM, K11_WN, K11_WNITER,
K11_TM, K11_TN, K11_NUM_THREADS>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}

void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C, cublasHandle_t handle) {
switch (kernel_num) {
@@ -406,6 +457,9 @@ void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
case 10:
runSgemmWarptiling(M, N, K, alpha, A, B, beta, C);
break;
case 11:
runSgemmDoubleBuffering(M, N, K, alpha, A, B, beta, C);
break;
default:
throw std::invalid_argument("Unknown kernel number");
}
