Add warptiling kernel
Simon Boehm committed Feb 20, 2023
1 parent 540fddb commit c60e244
Showing 6 changed files with 197 additions and 17 deletions.
2 changes: 1 addition & 1 deletion sgemm.cu
@@ -89,7 +89,7 @@ int main(int argc, char **argv) {
   cudaCheck(cudaMemcpy(dC_ref, C, sizeof(float) * max_size * max_size,
                        cudaMemcpyHostToDevice));
 
-  int repeat_times = 500;
+  int repeat_times = 50;
   for (int size : SIZE) {
     m = n = k = size;
5 changes: 3 additions & 2 deletions src/kernels.cuh
@@ -1,10 +1,11 @@
 #pragma once
 
+#include "kernels/10_kernel_warptiling.cuh"
 #include "kernels/1_naive.cuh"
 #include "kernels/2_kernel_global_mem_coalesce.cuh"
 #include "kernels/3_kernel_shared_mem_blocking.cuh"
-#include "kernels/4_kernel_1D_warptiling.cuh"
-#include "kernels/5_kernel_2D_warptiling.cuh"
+#include "kernels/4_kernel_1D_blocktiling.cuh"
+#include "kernels/5_kernel_2D_blocktiling.cuh"
 #include "kernels/6_kernel_vectorize.cuh"
 #include "kernels/7_kernel_resolve_bank_conflicts.cuh"
 #include "kernels/8_kernel_bank_extra_col.cuh"
134 changes: 134 additions & 0 deletions src/kernels/10_kernel_warptiling.cuh
@@ -0,0 +1,134 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
const int K10_NUM_THREADS = 256;

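// Tile hierarchy implemented below: each BM x BN blocktile (staged in SMEM)
// is processed as TBMITER x TBNITER iterations over TBM x TBN sub-tiles;
// each sub-tile is split across the block's warps into WM x WN warptiles,
// and every thread accumulates a TM x TN tile of results in registers.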
template <const int BM, const int BN, const int BK, const int TBM,
          const int TBN, const int WM, const int WN, const int TM, const int TN>
__global__ void __launch_bounds__(K10_NUM_THREADS)
    sgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B,
                    float beta, float *C) {
  const uint cRow = blockIdx.y;
  const uint cCol = blockIdx.x;

  // number of TBM x TBN sub-tile iterations needed to cover the blocktile
  constexpr int TBMITER = CEIL_DIV(BM, TBM);
  constexpr int TBNITER = CEIL_DIV(BN, TBN);

  // Placement of the warp in the threadblock tile
  const uint warpIdx = threadIdx.x / warpSize; // the warp this thread is in
  const uint warpCol = warpIdx % (TBN / WN);
  const uint warpRow = warpIdx / (TBN / WN);

  // Placement of the thread in the warp tile
  const uint threadIdxInWarp = threadIdx.x % warpSize; // [0, 31]
  const uint threadColInWarp = threadIdxInWarp % (WN / TN);
  const uint threadRowInWarp = threadIdxInWarp / (WN / TN);
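  // e.g. with the defaults from runner.cu (WM=32, WN=16, TM=TN=4): each warp
  // is an (32/4)x(16/4) = 8x4 grid of threads whose TMxTN tiles exactly
  // cover the WM x WN = 32x16 warptile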

  // allocate space for the current blocktile in SMEM
  __shared__ float As[BM * BK];
  __shared__ float Bs[BK * BN];

  // Move blocktile to beginning of A's row and B's column
  A += cRow * BM * K;
  B += cCol * BN;
  C += cRow * BM * N + cCol * BN;

  // calculate the indices that this thread will load into SMEM;
  // we'll load 128bit / 32bit = 4 elements per thread at each step
  const uint innerRowA = threadIdx.x / (BK / 4);
  const uint innerColA = threadIdx.x % (BK / 4);
  constexpr uint rowStrideA = (K10_NUM_THREADS * 4) / BK;
  const uint innerRowB = threadIdx.x / (BN / 4);
  const uint innerColB = threadIdx.x % (BN / 4);
  constexpr uint rowStrideB = K10_NUM_THREADS / (BN / 4);
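  // e.g. for K10_NUM_THREADS=256, BK=16, BM=BN=128 (defaults from runner.cu):
  // rowStrideA=64 and rowStrideB=8, so the GMEM->SMEM copy loops below run
  // BM/rowStrideA=2 and BK/rowStrideB=2 iterations per blocktile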

  // allocate thread-local cache for results in registerfile
  float threadResults[TBMITER * TM * TBNITER * TN] = {0.0};
  float regM[TM] = {0.0};
  float regN[TN] = {0.0};

  // outer-most loop over block tiles
  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
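    // the barrier can be skipped on the first iteration: no thread can
    // still be reading SMEM that is about to be overwritten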
    if (bkIdx != 0) {
      __syncthreads();
    }
    for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
      float4 tmp = reinterpret_cast<float4 *>(
          &A[(innerRowA + offset) * K + innerColA * 4])[0];
      // transpose A while storing it
      As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
      As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
      As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
      As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w;
    }

    for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
      reinterpret_cast<float4 *>(
          &Bs[(innerRowB + offset) * BN + innerColB * 4])[0] =
          reinterpret_cast<float4 *>(
              &B[(innerRowB + offset) * N + innerColB * 4])[0];
    }
    __syncthreads();

    for (uint tbmIdx = 0; tbmIdx < TBMITER; ++tbmIdx) {
      for (uint tbnIdx = 0; tbnIdx < TBNITER; ++tbnIdx) {
        // calculate per-thread results
        for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
          // block into registers
          for (uint i = 0; i < TM; ++i) {
            regM[i] = As[dotIdx * BM + (tbmIdx * TBM) + warpRow * WM +
                         threadRowInWarp * TM + i];
          }
          for (uint i = 0; i < TN; ++i) {
            regN[i] = Bs[dotIdx * BN + (tbnIdx * TBN) + warpCol * WN +
                         threadColInWarp * TN + i];
          }
          for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
            for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
              threadResults[(tbmIdx * TM + resIdxM) * (TBNITER * TN) +
                            tbnIdx * TN + resIdxN] +=
                  regM[resIdxM] * regN[resIdxN];
            }
          }
        }
      }
    }
    A += BK;     // move BK columns to right
    B += BK * N; // move BK rows down
  }

  // write out the results
  for (uint tbmIdx = 0; tbmIdx < TBMITER; ++tbmIdx) {
    for (uint tbnIdx = 0; tbnIdx < TBNITER; ++tbnIdx) {
      float *C_interim = C + (tbmIdx * TBM * N) + (tbnIdx * TBN);
      for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
        for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
          // load C vector into registers
          float4 tmp = reinterpret_cast<float4 *>(
              &C_interim[(warpRow * WM + threadRowInWarp * TM + resIdxM) * N +
                         warpCol * WN + threadColInWarp * TN + resIdxN])[0];
          // perform GEMM update in reg
          const int i =
              (tbmIdx * TM + resIdxM) * (TBNITER * TN) + tbnIdx * TN + resIdxN;
          tmp.x = alpha * threadResults[i + 0] + beta * tmp.x;
          tmp.y = alpha * threadResults[i + 1] + beta * tmp.y;
          tmp.z = alpha * threadResults[i + 2] + beta * tmp.z;
          tmp.w = alpha * threadResults[i + 3] + beta * tmp.w;
          // write back to the same address we loaded from
          reinterpret_cast<float4 *>(
              &C_interim[(warpRow * WM + threadRowInWarp * TM + resIdxM) * N +
                         warpCol * WN + threadColInWarp * TN + resIdxN])[0] =
              tmp;
        }
      }
    }
  }
}
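For reference, a minimal host-side sketch of the tile-hierarchy arithmetic this kernel relies on, using the default parameters from runner.cu below. The sketch is not part of the commit, and all names in it are local assumptions:

#include <cassert>
#include <cstdio>

// Checks that the warptiling hierarchy divides evenly for the defaults used
// in runner.cu (BM=BN=128, BK=16, TBM=TBN=64, WM=32, WN=16, TM=TN=4, 256
// threads per block).
int main() {
  const int BM = 128, BN = 128, BK = 16;
  const int TBM = 64, TBN = 64;
  const int WM = 32, WN = 16;
  const int TM = 4, TN = 4;
  const int NUM_THREADS = 256;

  // blocktile -> sub-tile iterations (TBMITER x TBNITER = 2 x 2)
  const int TBMITER = (BM + TBM - 1) / TBM;
  const int TBNITER = (BN + TBN - 1) / TBN;

  // the warps of the block must exactly tile one TBM x TBN sub-tile
  assert((TBM / WM) * (TBN / WN) == NUM_THREADS / 32); // 2 * 4 == 8 warps

  // the 32 lanes of a warp must exactly tile one WM x WN warptile
  assert((WM / TM) * (WN / TN) == 32); // 8 * 4 == 32 threads

  // register and SMEM footprint per the kernel's declarations
  const int resultsPerThread = TBMITER * TM * TBNITER * TN;       // 64 floats
  const int smemBytes = (BM * BK + BK * BN) * (int)sizeof(float); // 16 KiB

  printf("sub-tile iters: %dx%d, results/thread: %d, SMEM/block: %d bytes\n",
         TBMITER, TBNITER, resultsPerThread, smemBytes);
}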
6 changes: 3 additions & 3 deletions src/kernels/{4_kernel_1D_warptiling.cuh → 4_kernel_1D_blocktiling.cuh}
@@ -10,9 +10,9 @@
 #define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
 
 template <const int BM, const int BN, const int BK, const int TM>
-__global__ void sgemm1DWarpTiling(int M, int N, int K, float alpha,
-                                  const float *A, const float *B, float beta,
-                                  float *C) {
+__global__ void sgemm1DBlocktiling(int M, int N, int K, float alpha,
+                                   const float *A, const float *B, float beta,
+                                   float *C) {
   // If we flip x and y here we get ~30% less performance for large matrices.
   // The current, 30% faster configuration ensures that blocks with sequential
   // blockIDs access columns of B sequentially, while sharing the same row of A.
4 changes: 2 additions & 2 deletions src/kernels/{5_kernel_2D_warptiling.cuh → 5_kernel_2D_blocktiling.cuh}
@@ -11,8 +11,8 @@

 template <const int BM, const int BN, const int BK, const int TM, const int TN>
 __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1)
-    sgemm2DWarpTiling(int M, int N, int K, float alpha, const float *A,
-                      const float *B, float beta, float *C) {
+    sgemm2DBlocktiling(int M, int N, int K, float alpha, const float *A,
+                       const float *B, float beta, float *C) {
   const uint cRow = blockIdx.y;
   const uint cCol = blockIdx.x;

63 changes: 54 additions & 9 deletions src/runner.cu
@@ -178,20 +178,20 @@ void run_sgemm_shared_mem_block(int M, int N, int K, float alpha, float *A,
       <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
 }
 
-void runSgemm1DWarpTiling(int M, int N, int K, float alpha, float *A, float *B,
-                          float beta, float *C) {
+void runSgemm1DBlocktiling(int M, int N, int K, float alpha, float *A, float *B,
+                           float beta, float *C) {
   const uint BM = 64;
   const uint BN = 64;
   const uint BK = 8;
   const uint TM = 8;
   dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
   dim3 blockDim((BM * BN) / TM);
-  sgemm1DWarpTiling<BM, BN, BK, TM>
+  sgemm1DBlocktiling<BM, BN, BK, TM>
       <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
 }

-void runSgemm2DWarpTiling(int M, int N, int K, float alpha, float *A, float *B,
-                          float beta, float *C) {
+void runSgemm2DBlocktiling(int M, int N, int K, float alpha, float *A, float *B,
+                           float beta, float *C) {
   const uint BK = 8;
   const uint TM = 8;
   const uint TN = 8;
@@ -200,7 +200,7 @@ void runSgemm2DWarpTiling(int M, int N, int K, float alpha, float *A, float *B,
     const uint BN = 128;
     dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
     dim3 blockDim((BM * BN) / (TM * TN));
-    sgemm2DWarpTiling<BM, BN, BK, TM, TN>
+    sgemm2DBlocktiling<BM, BN, BK, TM, TN>
         <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
   } else {
     // this is a hacky solution to the underlying problem
@@ -209,7 +209,7 @@ void runSgemm2DWarpTiling(int M, int N, int K, float alpha, float *A, float *B,
     const uint BN = 64;
     dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
     dim3 blockDim((BM * BN) / (TM * TN));
-    sgemm2DWarpTiling<BM, BN, BK, TM, TN>
+    sgemm2DBlocktiling<BM, BN, BK, TM, TN>
         <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
   }
 }
@@ -321,6 +321,48 @@ void runSgemmAutotuned(int M, int N, int K, float alpha, float *A, float *B,
      <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}

void runSgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B,
                        float beta, float *C) {
  const uint K10_NUM_THREADS = 256;
  const uint K10_BN = 128;
  const uint K10_BM = 128;
  const uint K10_BK = 16;
  const uint K10_TBN = 64;
  const uint K10_TBM = 64;
  const uint K10_WN = 16;
  const uint K10_WM = 32;
  const uint K10_TN = 4;
  const uint K10_TM = 4;
  dim3 blockDim(K10_NUM_THREADS);
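  // with these settings each block allocates (128*16 + 16*128) floats of
  // SMEM (16 KiB) and launches 256 threads = 8 warps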

  static_assert(
      (K10_NUM_THREADS * 4) % K10_BK == 0,
      "NUM_THREADS*4 must be multiple of K10_BK to avoid quantization issues "
      "during GMEM->SMEM tiling (loading only parts of the final row of As "
      "during each iteration)");
  static_assert(
      (K10_NUM_THREADS * 4) % K10_BN == 0,
      "NUM_THREADS*4 must be multiple of K10_BN to avoid quantization issues "
      "during GMEM->SMEM tiling (loading only parts of the final row of Bs "
      "during each iteration)");
  static_assert(K10_BN % (16 * K10_TN) == 0,
                "BN must be a multiple of 16*TN to "
                "avoid quantization effects");
  static_assert(K10_BM % (16 * K10_TM) == 0,
                "BM must be a multiple of 16*TM to "
                "avoid quantization effects");
  static_assert((K10_BM * K10_BK) % (4 * K10_NUM_THREADS) == 0,
                "BM*BK must be a multiple of 4*NUM_THREADS to "
                "vectorize loads");
  static_assert((K10_BN * K10_BK) % (4 * K10_NUM_THREADS) == 0,
                "BN*BK must be a multiple of 4*NUM_THREADS to "
                "vectorize loads");

  dim3 gridDim(CEIL_DIV(N, K10_BN), CEIL_DIV(M, K10_BM));
  sgemmWarptiling<K10_BM, K10_BN, K10_BK, K10_TBM, K10_TBN, K10_WM, K10_WN,
                  K10_TM, K10_TN>
      <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}

void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
                float *B, float beta, float *C, cublasHandle_t handle) {
  switch (kernel_num) {
@@ -337,10 +379,10 @@ void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
     run_sgemm_shared_mem_block(M, N, K, alpha, A, B, beta, C);
     break;
   case 4:
-    runSgemm1DWarpTiling(M, N, K, alpha, A, B, beta, C);
+    runSgemm1DBlocktiling(M, N, K, alpha, A, B, beta, C);
     break;
   case 5:
-    runSgemm2DWarpTiling(M, N, K, alpha, A, B, beta, C);
+    runSgemm2DBlocktiling(M, N, K, alpha, A, B, beta, C);
     break;
   case 6:
     runSgemmVectorize(M, N, K, alpha, A, B, beta, C);
@@ -354,6 +396,9 @@ void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
   case 9:
     runSgemmAutotuned(M, N, K, alpha, A, B, beta, C);
     break;
+  case 10:
+    runSgemmWarptiling(M, N, K, alpha, A, B, beta, C);
+    break;
   default:
     throw std::invalid_argument("Unknown kernel number");
   }