From fb0a14dc79d4c86f952aa3d549c362ca704fdf36 Mon Sep 17 00:00:00 2001
From: Simon Boehm
Date: Sun, 29 Jan 2023 21:42:05 +0000
Subject: [PATCH] Add failed attempt at resolving bank conflicts

---
 plot_benchmark_results.py               |   5 +-
 scripts/bank_calc.py                    |  40 ++++++++
 src/kernels.cuh                         |   4 +-
 src/kernels/8_kernel_bank_extra_col.cuh | 129 ++++++++++++++++++++++++
 src/kernels/9_kernel_generalizing.cuh   | 120 ++++++++++++++++++++++
 src/runner.cu                           |  54 ++++++++++
 6 files changed, 350 insertions(+), 2 deletions(-)
 create mode 100644 scripts/bank_calc.py
 create mode 100644 src/kernels/8_kernel_bank_extra_col.cuh
 create mode 100644 src/kernels/9_kernel_generalizing.cuh

diff --git a/plot_benchmark_results.py b/plot_benchmark_results.py
index 84deda0..47c3e38 100755
--- a/plot_benchmark_results.py
+++ b/plot_benchmark_results.py
@@ -26,7 +26,10 @@
     3: "SMEM Caching",
     4: "1D Warptiling",
     5: "2D Warptiling",
-    6: "Vectorized Mem access",
+    6: "Vectorized Mem Access",
+    7: "Avoid Bank Conflicts",
+    8: "Double Buffering",
+    9: "Autotuning",
 }
 
 
diff --git a/scripts/bank_calc.py b/scripts/bank_calc.py
new file mode 100644
index 0000000..f9f8d03
--- /dev/null
+++ b/scripts/bank_calc.py
@@ -0,0 +1,40 @@
+banks_naive = lambda r, c: (r * 32 + c) % 32
+banks_one_extra = lambda r, c: (r * 33 + c) % 32
+
+ITEMS_PER_WARP = 8
+
+
+def printBankConflicts(bank_fun, row_width=None):
+    """Print the bank computed for each index i in 0..31 and a conflict summary.
+
+    bank_fun maps a (row, col) pair to a bank index in 0..31.
+    row_width is accepted only because the calls at the bottom of this script
+    pass the row stride as a second argument; bank_fun already encodes it, so
+    the value is not used.
+    """
+    for c in range(1):
+        banks = []
+        for i in range(32):
+            row = (i * ITEMS_PER_WARP) // 16
+            col = (i * ITEMS_PER_WARP + c) % 16
+            banks.append((i, row, col, bank_fun(row, col)))
+        print("Step", c, "\n", "\n".join(["(" + ",".join(str(x) for x in i) + ")" for i in banks]))
+        d = {k: 0 for k in range(32)}
+        for i in banks:
+            d[i[-1]] += 1
+
+        count = 0
+        for key, val in d.items():
+            if val > 0:
+                count += 1
+
+        print(
+            f"Bank conflicts (Step {c}): {sorted(d.items(), key=lambda item: item[1], reverse=True)[0][1]}, banks accessed: {count}/32\n"
+        )
+
+
+print("---NAIVE---")
+printBankConflicts(banks_naive, 32)
+
+print("\n---EXTRA COL---")
+printBankConflicts(banks_one_extra, 33)
diff --git a/src/kernels.cuh b/src/kernels.cuh
index f9221f0..e6a987d 100644
--- a/src/kernels.cuh
+++ b/src/kernels.cuh
@@ -6,4 +6,6 @@
 #include "kernels/4_kernel_1D_warptiling.cuh"
 #include "kernels/5_kernel_2D_warptiling.cuh"
 #include "kernels/6_kernel_vectorize.cuh"
-#include "kernels/7_kernel_resolve_bank_conflicts.cuh"
\ No newline at end of file
+#include "kernels/7_kernel_resolve_bank_conflicts.cuh"
+#include "kernels/8_kernel_bank_extra_col.cuh"
+#include "kernels/9_kernel_generalizing.cuh"
\ No newline at end of file
diff --git a/src/kernels/8_kernel_bank_extra_col.cuh b/src/kernels/8_kernel_bank_extra_col.cuh
new file mode 100644
index 0000000..b22085a
--- /dev/null
+++ b/src/kernels/8_kernel_bank_extra_col.cuh
@@ -0,0 +1,129 @@
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
+
+// SGEMM: C = alpha * (A @ B) + beta * C, all matrices row-major
+// (A is MxK, B is KxN, C is MxN). Rows of the B tile in SMEM are padded by
+// extraCols floats. Launch: 1D block of (BM * BN) / (TM * TN) threads;
+// blockIdx.x picks the BN-wide column tile, blockIdx.y the BM-tall row tile.
+// No bounds checks: M, N, K must be multiples of BM, BN, BK (and of 4).
+template <const int BM, const int BN, const int BK, const int TM, const int TN>
+__global__ void sgemmResolveBankExtraCol(int M, int N, int K, float alpha,
+                                         float *A, float *B, float beta,
+                                         float *C) {
+  const uint cRow = blockIdx.y;
+  const uint cCol = blockIdx.x;
+
+  const uint totalResultsBlocktile = BM * BN;
+  // A thread is responsible for calculating TM*TN elements in the blocktile
+  const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);
+
+  // ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
+  assert(numThreadsBlocktile == blockDim.x);
+
+  // BN/TN are the number of threads to span a column
+  const int threadCol = threadIdx.x % (BN / TN);
+  const int threadRow = threadIdx.x / (BN / TN);
+
+  // allocate space for the current blocktile in smem
+  __shared__ float As[BM * BK];
+  const int extraCols = 5;
+  __shared__ float Bs[BK * (BN + extraCols)];
+
+  // Move blocktile to beginning of A's row and B's column
+  A += cRow * BM * K;
+  B += cCol * BN;
+  C += cRow * BM * N + cCol * BN;
+
+  // calculating the indices that this thread will load into SMEM
+  // we'll load 128bit / 32bit = 4 elements per thread at each step
+  const uint innerRowA = threadIdx.x / (BK / 4);
+  const uint innerColA = threadIdx.x % (BK / 4);
+  // calculates the number of rows of As that are being loaded in a single step
+  // by a single block
+  const uint rowStrideA = (numThreadsBlocktile * 4) / BK;
+  const uint innerRowB = threadIdx.x / (BN / 4);
+  const uint innerColB = threadIdx.x % (BN / 4);
+  // for both As and Bs we want each load to span the full column-width, for
+  // better GMEM coalescing (as opposed to spanning full row-width and iterating
+  // across columns)
+  const uint rowStrideB = numThreadsBlocktile / (BN / 4);
+
+  // allocate thread-local cache for results in registerfile
+  float threadResults[TM * TN] = {0.0f};
+  float regM[TM] = {0.0f};
+  float regN[TN] = {0.0f};
+
+  // outer-most loop over block tiles
+  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
+    // populate the SMEM caches
+    // transpose A while loading it; one pass covers rowStrideA rows, so step
+    // through the tile until all BM rows are written
+    for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
+      float4 tmp = reinterpret_cast<float4 *>(
+          &A[(innerRowA + offset) * K + innerColA * 4])[0];
+      As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
+      As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
+      As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
+      As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w;
+    }
+
+    // same for B: one pass covers rowStrideB of the BK rows
+    for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
+      float4 tmp = reinterpret_cast<float4 *>(
+          &B[(innerRowB + offset) * N + innerColB * 4])[0];
+      Bs[(innerRowB + offset) * (BN + extraCols) + innerColB * 4 + 0] = tmp.x;
+      Bs[(innerRowB + offset) * (BN + extraCols) + innerColB * 4 + 1] = tmp.y;
+      Bs[(innerRowB + offset) * (BN + extraCols) + innerColB * 4 + 2] = tmp.z;
+      Bs[(innerRowB + offset) * (BN + extraCols) + innerColB * 4 + 3] = tmp.w;
+    }
+    __syncthreads();
+
+    // advance blocktile
+    A += BK;     // move BK columns to right
+    B += BK * N; // move BK rows down
+
+    // calculate per-thread results
+    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
+      // block into registers
+      for (uint i = 0; i < TM; ++i) {
+        regM[i] = As[dotIdx * BM + threadRow * TM + i];
+      }
+      for (uint i = 0; i < TN; ++i) {
+        regN[i] = Bs[dotIdx * (BN + extraCols) + threadCol * TN + i];
+      }
+      for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
+        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
+          threadResults[resIdxM * TN + resIdxN] +=
+              regM[resIdxM] * regN[resIdxN];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // write out the results
+  for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
+    for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
+      // load C vector into registers
+      float4 tmp = reinterpret_cast<float4 *>(
+          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0];
+      // perform GEMM update in reg
+      tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x;
+      tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y;
+      tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z;
+      tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w;
+      // write back
+      reinterpret_cast<float4 *>(
+          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] =
+          tmp;
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/kernels/9_kernel_generalizing.cuh b/src/kernels/9_kernel_generalizing.cuh
new file mode 100644
index 0000000..e6c98b3
--- /dev/null
+++ b/src/kernels/9_kernel_generalizing.cuh
@@ -0,0 +1,120 @@
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
+
+// SGEMM: C = alpha * (A @ B) + beta * C, all matrices row-major
+// (A is MxK, B is KxN, C is MxN). Each thread computes a TM x TN sub-tile.
+// Launch: 1D block of (BM * BN) / (TM * TN) threads; blockIdx.x picks the
+// BN-wide column tile, blockIdx.y the BM-tall row tile. No bounds checks:
+// M, N, K must be multiples of BM, BN, BK (float4 access: N, K % 4 == 0).
+template <const int BM, const int BN, const int BK, const int TM, const int TN>
+__global__ void sgemmGeneralize(int M, int N, int K, float alpha, float *A,
+                                float *B, float beta, float *C) {
+  const uint cRow = blockIdx.y;
+  const uint cCol = blockIdx.x;
+
+  const uint totalResultsBlocktile = BM * BN;
+  // A thread is responsible for calculating TM*TN elements in the blocktile
+  const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);
+
+  // ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
+  assert(numThreadsBlocktile == blockDim.x);
+
+  // BN/TN are the number of threads to span a column
+  const int threadCol = threadIdx.x % (BN / TN);
+  const int threadRow = threadIdx.x / (BN / TN);
+
+  // allocate space for the current blocktile in smem
+  __shared__ float As[BM * BK];
+  __shared__ float Bs[BK * BN];
+
+  // Move blocktile to beginning of A's row and B's column
+  A += cRow * BM * K;
+  B += cCol * BN;
+  C += cRow * BM * N + cCol * BN;
+
+  // calculating the indices that this thread will load into SMEM
+  // we'll load 128bit / 32bit = 4 elements per thread at each step
+  const uint innerRowA = threadIdx.x / (BK / 4);
+  const uint innerColA = threadIdx.x % (BK / 4);
+  const uint rowStrideA = (numThreadsBlocktile * 4) / BK;
+  const uint innerRowB = threadIdx.x / (BN / 4);
+  const uint innerColB = threadIdx.x % (BN / 4);
+  const uint rowStrideB = numThreadsBlocktile / (BN / 4);
+
+  // allocate thread-local cache for results in registerfile
+  float threadResults[TM * TN] = {0.0f};
+  float regM[TM] = {0.0f};
+  float regN[TN] = {0.0f};
+
+  // outer-most loop over block tiles
+  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
+    // populate the SMEM caches
+    // transpose A while loading it; one pass covers rowStrideA rows, so step
+    // through the tile until all BM rows are written
+    for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
+      float4 tmp = reinterpret_cast<float4 *>(
+          &A[(innerRowA + offset) * K + innerColA * 4])[0];
+      As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
+      As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
+      As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
+      As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w;
+    }
+
+    // same for B: one pass covers rowStrideB of the BK rows
+    for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
+      reinterpret_cast<float4 *>(
+          &Bs[(innerRowB + offset) * BN + innerColB * 4])[0] =
+          reinterpret_cast<float4 *>(
+              &B[(innerRowB + offset) * N + innerColB * 4])[0];
+    }
+    __syncthreads();
+
+    // advance blocktile
+    A += BK;     // move BK columns to right
+    B += BK * N; // move BK rows down
+
+    // calculate per-thread results
+    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
+      // block into registers
+      for (uint i = 0; i < TM; ++i) {
+        regM[i] = As[dotIdx * BM + threadRow * TM + i];
+      }
+      for (uint i = 0; i < TN; ++i) {
+        regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
+      }
+      for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
+        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
+          threadResults[resIdxM * TN + resIdxN] +=
+              regM[resIdxM] * regN[resIdxN];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // write out the results
+  for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
+    for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
+      // load C vector into registers
+      float4 tmp = reinterpret_cast<float4 *>(
+          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0];
+      // perform GEMM update in reg
+      tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x;
+      tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y;
+      tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z;
+      tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w;
+      // write back
+      reinterpret_cast<float4 *>(
+          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] =
+          tmp;
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/runner.cu b/src/runner.cu
index 30f7c19..7026797 100644
--- a/src/runner.cu
+++ b/src/runner.cu
@@ -262,6 +262,54 @@ void runSgemmResolveBankConflicts(int M, int N, int K, float alpha, float *A,
   }
 }
 
+void runSgemmResolveBankExtraCol(int M, int N, int K, float alpha, float *A,
+                                 float *B, float beta, float *C) {
+  const uint BK = 8;
+  const uint TM = 8;
+  const uint TN = 8;
+  if (M >= 128 and N >= 128) {
+    const uint BM = 128;
+    const uint BN = 128;
+    dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
+    dim3 blockDim((BM * BN) / (TM * TN));
+    sgemmResolveBankExtraCol<BM, BN, BK, TM, TN>
+        <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
+  } else {
+    // this is a hacky solution to the underlying problem
+    // of not having proper bounds checking in the kernel
+    const uint BM = 64;
+    const uint BN = 64;
+    dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
+    dim3 blockDim((BM * BN) / (TM * TN));
+    sgemmResolveBankExtraCol<BM, BN, BK, TM, TN>
+        <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
+  }
+}
+
+void runSgemmGeneralize(int M, int N, int K, float alpha, float *A, float *B,
+                        float beta, float *C) {
+  const uint BK = 16;
+  const uint TM = 8;
+  const uint TN = 8;
+  if (M >= 128 and N >= 128) {
+    const uint BM = 128;
+    const uint BN = 128;
+    dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
+    dim3 blockDim((BM * BN) / (TM * TN));
+    sgemmGeneralize<BM, BN, BK, TM, TN>
+        <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
+  } else {
+    // this is a hacky solution to the underlying problem
+    // of not having proper bounds checking in the kernel
+    const uint BM = 64;
+    const uint BN = 64;
+    dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
+    dim3 blockDim((BM * BN) / (TM * TN));
+    sgemmGeneralize<BM, BN, BK, TM, TN>
+        <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
+  }
+}
+
 void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
                 float *B, float beta, float *C, cublasHandle_t handle) {
   switch (kernel_num) {
@@ -289,6 +337,12 @@ void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
   case 7:
     runSgemmResolveBankConflicts(M, N, K, alpha, A, B, beta, C);
     break;
+  case 8:
+    runSgemmResolveBankExtraCol(M, N, K, alpha, A, B, beta, C);
+    break;
+  case 9:
+    runSgemmGeneralize(M, N, K, alpha, A, B, beta, C);
+    break;
   default:
     throw std::invalid_argument("Unknown kernel number");
   }