Skip to content

Commit

Permalink
Add a different double buffering implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Boehm committed Mar 11, 2023
1 parent b64882f commit de3c245
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 19 deletions.
6 changes: 3 additions & 3 deletions sgemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ const std::string errLogFile = "matrixValidationFailure.txt";

int main(int argc, char **argv) {
if (argc != 2) {
std::cerr << "Please select a kernel (range 0 - 11, 0 for NVIDIA cuBLAS)"
std::cerr << "Please select a kernel (range 0 - 12, 0 for NVIDIA cuBLAS)"
<< std::endl;
exit(EXIT_FAILURE);
}

// get kernel number
int kernel_num = std::stoi(argv[1]);
if (kernel_num < 0 || kernel_num > 11) {
std::cerr << "Please enter a valid kernel number (0-11)" << std::endl;
if (kernel_num < 0 || kernel_num > 12) {
std::cerr << "Please enter a valid kernel number (0-12)" << std::endl;
exit(EXIT_FAILURE);
}

Expand Down
1 change: 1 addition & 0 deletions src/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "kernels/10_kernel_warptiling.cuh"
#include "kernels/11_kernel_double_buffering.cuh"
#include "kernels/12_kernel_double_buffering.cuh"
#include "kernels/1_naive.cuh"
#include "kernels/2_kernel_global_mem_coalesce.cuh"
#include "kernels/3_kernel_shared_mem_blocking.cuh"
Expand Down
202 changes: 202 additions & 0 deletions src/kernels/12_kernel_double_buffering.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

namespace {
template <const int BM, const int BN, const int BK, const int rowStrideA,
const int rowStrideB>
__device__ void loadFromGmem(int N, int K, float *A, float *B, float *As,
float *Bs, int innerRowA, int innerColA,
int innerRowB, int innerColB) {
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
float4 tmp = reinterpret_cast<float4 *>(
&A[(innerRowA + offset) * K + innerColA * 4])[0];
// transpose A while storing it
As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w;
}

for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
reinterpret_cast<float4 *>(
&Bs[(innerRowB + offset) * BN + innerColB * 4])[0] =
reinterpret_cast<float4 *>(
&B[(innerRowB + offset) * N + innerColB * 4])[0];
}
}

template <const int BM, const int BN, const int BK, const int WM, const int WN,
const int WMITER, const int WNITER, const int WSUBM, const int WSUBN,
const int TM, const int TN>
__device__ void
processFromSmem(float *regM, float *regN, float *threadResults, const float *As,
const float *Bs, const uint warpRow, const uint warpCol,
const uint threadRowInWarp, const uint threadColInWarp) {
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// populate registers for whole warptile
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint i = 0; i < TM; ++i) {
regM[wSubRowIdx * TM + i] =
As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM +
threadRowInWarp * TM + i];
}
}
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
for (uint i = 0; i < TN; ++i) {
regN[wSubColIdx * TN + i] =
Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN +
threadColInWarp * TN + i];
}
}

// execute warptile matmul
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
// calculate per-thread results
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
(wSubColIdx * TN) + resIdxN] +=
regM[wSubRowIdx * TM + resIdxM] *
regN[wSubColIdx * TN + resIdxN];
}
}
}
}
}
}

} // namespace

/*
* @tparam BM The threadblock size for M dimension SMEM caching.
* @tparam BN The threadblock size for N dimension SMEM caching.
* @tparam BK The threadblock size for K dimension SMEM caching.
* @tparam WM M dim of continuous tile computed by each warp
* @tparam WN N dim of continuous tile computed by each warp
* @tparam WMITER The number of subwarp tiling steps in M dimension.
* @tparam WNITER The number of subwarp tiling steps in N dimension.
* @tparam TM The per-thread tile size for M dimension.
* @tparam TN The per-thread tile size for N dimension.
*/
template <const int BM, const int BN, const int BK, const int WM, const int WN,
const int WNITER, const int TM, const int TN, const int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
runSgemmDoubleBuffering2(int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C) {
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;

// Placement of the warp in the threadblock tile
const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in
const uint warpCol = warpIdx % (BN / WN);
const uint warpRow = warpIdx / (BN / WN);

// size of the warp subtile
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
constexpr uint WSUBM = WM / WMITER; // 64/2=32
constexpr uint WSUBN = WN / WNITER; // 32/2=16

// Placement of the thread in the warp subtile
const uint threadIdxInWarp = threadIdx.x % WARPSIZE; // [0, 31]
const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4)
const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4

// allocate space for the current blocktile in SMEM
__shared__ float As[2 * BM * BK];
__shared__ float Bs[2 * BK * BN];

// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
// Move C_ptr to warp's output tile
C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN;

// calculating the indices that this thread will load into SMEM
// we'll load 128bit / 32bit = 4 elements per thread at each step
const uint innerRowA = threadIdx.x / (BK / 4);
const uint innerColA = threadIdx.x % (BK / 4);
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
const uint innerRowB = threadIdx.x / (BN / 4);
const uint innerColB = threadIdx.x % (BN / 4);
constexpr uint rowStrideB = NUM_THREADS / (BN / 4);

// allocate thread-local cache for results in registerfile
float threadResults[WMITER * TM * WNITER * TN] = {0.0};
// we cache into registers on the warptile level
float regM[WMITER * TM] = {0.0};
float regN[WNITER * TN] = {0.0};

int As_offset = 0;
int Bs_offset = 0;

// double-buffering: load first blocktile into SMEM
loadFromGmem<BM, BN, BK, rowStrideA, rowStrideB>(
N, K, A, B, As + As_offset * BM * BK, Bs + Bs_offset * BK * BN, innerRowA,
innerColA, innerRowB, innerColB);

__syncthreads();

// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K - BK; bkIdx += BK) {
// double-buffering: load next blocktile into SMEM
loadFromGmem<BM, BN, BK, rowStrideA, rowStrideB>(
N, K, A + BK, B + BK * N, As + (1 - As_offset) * BM * BK,
Bs + (1 - Bs_offset) * BK * BN, innerRowA, innerColA, innerRowB,
innerColB);

// compute the current blocktile
processFromSmem<BM, BN, BK, WM, WN, WMITER, WNITER, WSUBM, WSUBN, TM, TN>(
regM, regN, threadResults, As + As_offset * BM * BK,
Bs + Bs_offset * BK * BN, warpRow, warpCol, threadRowInWarp,
threadColInWarp);
A += BK; // move BK columns to right
B += BK * N; // move BK rows down

__syncthreads();

As_offset = 1 - As_offset;
Bs_offset = 1 - Bs_offset;
}

// compute the last blocktile
processFromSmem<BM, BN, BK, WM, WN, WMITER, WNITER, WSUBM, WSUBN, TM, TN>(
regM, regN, threadResults, As + As_offset * BM * BK,
Bs + Bs_offset * BK * BN, warpRow, warpCol, threadRowInWarp,
threadColInWarp);

// write out the results
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) {
// move C pointer to current warp subtile
float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN;
for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
// load C vector into registers
float4 tmp = reinterpret_cast<float4 *>(
&C_interim[(threadRowInWarp * TM + resIdxM) * N +
threadColInWarp * TN + resIdxN])[0];
// perform GEMM update in reg
const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
wSubColIdx * TN + resIdxN;
tmp.x = alpha * threadResults[i + 0] + beta * tmp.x;
tmp.y = alpha * threadResults[i + 1] + beta * tmp.y;
tmp.z = alpha * threadResults[i + 2] + beta * tmp.z;
tmp.w = alpha * threadResults[i + 3] + beta * tmp.w;
// write back
reinterpret_cast<float4 *>(
&C_interim[(threadRowInWarp * TM + resIdxM) * N +
threadColInWarp * TN + resIdxN])[0] = tmp;
}
}
}
}
}
86 changes: 70 additions & 16 deletions src/runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -392,25 +392,25 @@ void runSgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B,
void runSgemmDoubleBuffering(int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C) {
// Settings for A100
const uint K11_NUM_THREADS = 256;
const uint K11_BN = 128;
const uint K11_BM = 64;
const uint K11_BK = 16;
const uint K11_WN = 32;
const uint K11_WM = 32;
const uint K11_WNITER = 2;
const uint K11_TN = 4;
const uint K11_TM = 4;
// Settings for A6000
// const uint K11_NUM_THREADS = 256;
// const uint K11_BN = 256;
// const uint K11_BM = 128;
// const uint K11_BN = 128;
// const uint K11_BM = 64;
// const uint K11_BK = 16;
// const uint K11_WN = 32;
// const uint K11_WM = 128;
// const uint K11_WNITER = 1;
// const uint K11_TN = 8;
// const uint K11_TM = 8;
// const uint K11_WM = 32;
// const uint K11_WNITER = 2;
// const uint K11_TN = 4;
// const uint K11_TM = 4;
// Settings for A6000
const uint K11_NUM_THREADS = 256;
const uint K11_BN = 256;
const uint K11_BM = 128;
const uint K11_BK = 16;
const uint K11_WN = 32;
const uint K11_WM = 128;
const uint K11_WNITER = 1;
const uint K11_TN = 8;
const uint K11_TM = 8;
dim3 blockDim(K11_NUM_THREADS);

constexpr uint NUM_WARPS = K11_NUM_THREADS / 32;
Expand Down Expand Up @@ -450,6 +450,57 @@ void runSgemmDoubleBuffering(int M, int N, int K, float alpha, float *A,
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}

void runSgemmDoubleBuffering2(int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C) {
// Settings for A6000
const uint K12_NUM_THREADS = 128;
const uint K12_BN = 128;
const uint K12_BM = 128;
const uint K12_BK = 16;
const uint K12_WN = 64;
const uint K12_WM = 64;
const uint K12_WNITER = 4;
const uint K12_TN = 4;
const uint K12_TM = 8;
dim3 blockDim(K12_NUM_THREADS);

constexpr uint NUM_WARPS = K12_NUM_THREADS / 32;

// warptile in threadblocktile
static_assert((K12_BN % K12_WN == 0) and (K12_BM % K12_WM == 0));
static_assert((K12_BN / K12_WN) * (K12_BM / K12_WM) == NUM_WARPS);

// threads in warpsubtile
static_assert((K12_WM * K12_WN) % (WARPSIZE * K12_TM * K12_TN * K12_WNITER) ==
0);
constexpr uint K12_WMITER =
(K12_WM * K12_WN) / (32 * K12_TM * K12_TN * K12_WNITER);
// warpsubtile in warptile
static_assert((K12_WM % K12_WMITER == 0) and (K12_WN % K12_WNITER == 0));

static_assert((K12_NUM_THREADS * 4) % K12_BK == 0,
"NUM_THREADS*4 must be multiple of K9_BK to avoid quantization "
"issues during GMEM->SMEM tiling (loading only parts of the "
"final row of Bs during each iteraion)");
static_assert((K12_NUM_THREADS * 4) % K12_BN == 0,
"NUM_THREADS*4 must be multiple of K9_BN to avoid quantization "
"issues during GMEM->SMEM tiling (loading only parts of the "
"final row of As during each iteration)");
static_assert(K12_BN % (16 * K12_TN) == 0,
"BN must be a multiple of 16*TN to avoid quantization effects");
static_assert(K12_BM % (16 * K12_TM) == 0,
"BM must be a multiple of 16*TM to avoid quantization effects");
static_assert((K12_BM * K12_BK) % (4 * K12_NUM_THREADS) == 0,
"BM*BK must be a multiple of 4*256 to vectorize loads");
static_assert((K12_BN * K12_BK) % (4 * K12_NUM_THREADS) == 0,
"BN*BK must be a multiple of 4*256 to vectorize loads");

dim3 gridDim(CEIL_DIV(N, K12_BN), CEIL_DIV(M, K12_BM));
runSgemmDoubleBuffering2<K12_BM, K12_BN, K12_BK, K12_WM, K12_WN, K12_WNITER,
K12_TM, K12_TN, K12_NUM_THREADS>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}

void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C, cublasHandle_t handle) {
switch (kernel_num) {
Expand Down Expand Up @@ -489,6 +540,9 @@ void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
case 11:
runSgemmDoubleBuffering(M, N, K, alpha, A, B, beta, C);
break;
case 12:
runSgemmDoubleBuffering2(M, N, K, alpha, A, B, beta, C);
break;
default:
throw std::invalid_argument("Unknown kernel number");
}
Expand Down

0 comments on commit de3c245

Please sign in to comment.