Add failed attempt at resolving bank conflicts
Simon Boehm committed Jan 29, 2023
1 parent 33322ce commit fb0a14d
Showing 6 changed files with 318 additions and 2 deletions.
5 changes: 4 additions & 1 deletion plot_benchmark_results.py
@@ -26,7 +26,10 @@
3: "SMEM Caching",
4: "1D Warptiling",
5: "2D Warptiling",
6: "Vectorized Mem access",
6: "Vectorized Mem Access",
7: "Avoid Bank Conflicts",
8: "Double Buffering",
9: "Autotuning",
}


33 changes: 33 additions & 0 deletions scripts/bank_calc.py
@@ -0,0 +1,33 @@
# bank of the float at (row, col) for a row stride of 32 floats (no padding)
banks_naive = lambda r, c: (r * 32 + c) % 32
# same, with one extra padding column per row (row stride 33)
banks_one_extra = lambda r, c: (r * 33 + c) % 32

ITEMS_PER_WARP = 8


def printBankConflicts(bank_fun):
    # simulate one warp: thread i starts ITEMS_PER_WARP elements into a
    # 16-column tile; record (thread, row, col, bank) for each access
    for c in range(1):
        banks = []
        for i in range(32):
            row = (i * ITEMS_PER_WARP) // 16
            col = (i * ITEMS_PER_WARP + c) % 16
            banks.append((i, row, col, bank_fun(row, col)))
        print("Step", c, "\n", "\n".join(["(" + ",".join(str(x) for x in i) + ")" for i in banks]))
        # histogram: number of accesses that land in each bank
        d = {k: 0 for k in range(32)}
        for i in banks:
            d[i[-1]] += 1

        # number of distinct banks touched by the warp
        count = 0
        for key, val in d.items():
            if val > 0:
                count += 1

        print(
            f"Bank conflicts (Step {c}): {sorted(d.items(), key=lambda item: item[1], reverse=True)[0][1]}, banks accessed: {count}/32\n"
        )


print("---NAIVE---")
printBankConflicts(banks_naive)

print("\n---EXTRA COL---")
printBankConflicts(banks_one_extra)
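
For reference, the same arithmetic from the CUDA side: a minimal sketch, assuming the usual shared-memory layout of 32 banks, each 4 bytes wide. The bankOf helper is illustrative and not part of this commit.

// Illustrative only: bank of a 4-byte SMEM word, assuming 32 banks.
__device__ __forceinline__ uint bankOf(uint floatIdx) { return floatIdx % 32; }

// Unpadded Bs (row stride BN = 128): bankOf(row * 128) == 0 for every row,
// so all rows start in bank 0 and collide on the same few banks.
// Padded Bs (row stride BN + 5 = 133): bankOf(row * 133) == (row * 5) % 32,
// so consecutive rows start in different banks.

The script models a single extra column (row stride 33) while kernel 8 below pads by 5 columns (stride 133); the effect is analogous, since any odd row stride staggers the row starts across all 32 banks. With ITEMS_PER_WARP = 8, the script should report a 16-way conflict touching only 2/32 banks for the naive stride-32 layout, versus at most a 2-way conflict spread over 24/32 banks with the extra column.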
4 changes: 3 additions & 1 deletion src/kernels.cuh
Expand Up @@ -6,4 +6,6 @@
#include "kernels/4_kernel_1D_warptiling.cuh"
#include "kernels/5_kernel_2D_warptiling.cuh"
#include "kernels/6_kernel_vectorize.cuh"
#include "kernels/7_kernel_resolve_bank_conflicts.cuh"
#include "kernels/7_kernel_resolve_bank_conflicts.cuh"
#include "kernels/8_kernel_bank_extra_col.cuh"
#include "kernels/9_kernel_generalizing.cuh"
117 changes: 117 additions & 0 deletions src/kernels/8_kernel_bank_extra_col.cuh
@@ -0,0 +1,117 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemmResolveBankExtraCol(int M, int N, int K, float alpha,
float *A, float *B, float beta,
float *C) {
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;

const uint totalResultsBlocktile = BM * BN;
// A thread is responsible for calculating TM*TN elements in the blocktile
const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);

// ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
assert(numThreadsBlocktile == blockDim.x);

// BN/TN are the number of threads to span a column
const int threadCol = threadIdx.x % (BN / TN);
const int threadRow = threadIdx.x / (BN / TN);

// allocate space for the current blocktile in smem
__shared__ float As[BM * BK];
const int extraCols = 5;
__shared__ float Bs[BK * (BN + extraCols)];

// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
C += cRow * BM * N + cCol * BN;

// calculating the indices that this thread will load into SMEM
// we'll load 128bit / 32bit = 4 elements per thread at each step
const uint innerRowA = threadIdx.x / (BK / 4);
const uint innerColA = threadIdx.x % (BK / 4);
// calculates the number of rows of As that are being loaded in a single step
// by a single block
const uint rowStrideA = (numThreadsBlocktile * 4) / BK;
const uint innerRowB = threadIdx.x / (BN / 4);
const uint innerColB = threadIdx.x % (BN / 4);
// for both As and Bs we want each load to span the full column-width, for
// better GMEM coalescing (as opposed to spanning full row-width and iterating
// across columns)
const uint rowStrideB = numThreadsBlocktile / (BN / 4);

// allocate thread-local cache for results in registerfile
float threadResults[TM * TN] = {0.0};
float regM[TM] = {0.0};
float regN[TN] = {0.0};

// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// populate the SMEM caches
// transpose A while loading it
float4 tmp =
reinterpret_cast<float4 *>(&A[innerRowA * K + innerColA * 4])[0];
As[(innerColA * 4 + 0) * BM + innerRowA] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA] = tmp.z;
As[(innerColA * 4 + 3) * BM + innerRowA] = tmp.w;

tmp = reinterpret_cast<float4 *>(&B[innerRowB * N + innerColB * 4])[0];
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 0] = tmp.x;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 1] = tmp.y;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 2] = tmp.z;
Bs[innerRowB * (BN + extraCols) + innerColB * 4 + 3] = tmp.w;
__syncthreads();

// advance blocktile
A += BK; // move BK columns to right
B += BK * N; // move BK rows down

// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// block into registers
for (uint i = 0; i < TM; ++i) {
regM[i] = As[dotIdx * BM + threadRow * TM + i];
}
for (uint i = 0; i < TN; ++i) {
regN[i] = Bs[dotIdx * (BN + extraCols) + threadCol * TN + i];
}
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[resIdxM * TN + resIdxN] +=
regM[resIdxM] * regN[resIdxN];
}
}
}
__syncthreads();
}

// write out the results
for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
// load C vector into registers
float4 tmp = reinterpret_cast<float4 *>(
&C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0];
// perform GEMM update in reg
tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x;
tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y;
tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z;
tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w;
// write back
reinterpret_cast<float4 *>(
&C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] =
tmp;
}
}
}
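
Aside: unlike the vectorized kernel, Bs is populated with four scalar stores here rather than a single float4 store, presumably because the padding breaks 128-bit alignment. With a row stride of BN + extraCols = 133 floats, odd rows of Bs start at byte offsets that are not multiples of 16, so a vectorized store into the padded tile would be misaligned; the float4 load from B itself is still fine, since B is unpadded.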
107 changes: 107 additions & 0 deletions src/kernels/9_kernel_generalizing.cuh
@@ -0,0 +1,107 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemmGeneralize(int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C) {
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;

const uint totalResultsBlocktile = BM * BN;
// A thread is responsible for calculating TM*TN elements in the blocktile
const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);

// ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
assert(numThreadsBlocktile == blockDim.x);

// BN/TN are the number of threads to span a column
const int threadCol = threadIdx.x % (BN / TN);
const int threadRow = threadIdx.x / (BN / TN);

// allocate space for the current blocktile in smem
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];

// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
C += cRow * BM * N + cCol * BN;

// calculating the indices that this thread will load into SMEM
// we'll load 128bit / 32bit = 4 elements per thread at each step
const uint innerRowA = threadIdx.x / (BK / 4);
const uint innerColA = threadIdx.x % (BK / 4);
const uint rowStrideA = (numThreadsBlocktile * 4) / BK;
const uint innerRowB = threadIdx.x / (BN / 4);
const uint innerColB = threadIdx.x % (BN / 4);
const uint rowStrideB = numThreadsBlocktile / (BN / 4);

// allocate thread-local cache for results in registerfile
float threadResults[TM * TN] = {0.0};
float regM[TM] = {0.0};
float regN[TN] = {0.0};

// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
    // populate the SMEM caches; with BK = 16, one float4 per thread covers
    // only half of each tile, so threads load multiple rows, stepping by
    // rowStrideA / rowStrideB (assumed to evenly divide BM / BK)
    for (uint offset = 0; offset < BM; offset += rowStrideA) {
      // transpose A while loading it
      float4 tmp = reinterpret_cast<float4 *>(
          &A[(innerRowA + offset) * K + innerColA * 4])[0];
      As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
      As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
      As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
      As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w;
    }
    for (uint offset = 0; offset < BK; offset += rowStrideB) {
      reinterpret_cast<float4 *>(
          &Bs[(innerRowB + offset) * BN + innerColB * 4])[0] =
          reinterpret_cast<float4 *>(
              &B[(innerRowB + offset) * N + innerColB * 4])[0];
    }
__syncthreads();

// advance blocktile
A += BK; // move BK columns to right
B += BK * N; // move BK rows down

// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// block into registers
for (uint i = 0; i < TM; ++i) {
regM[i] = As[dotIdx * BM + threadRow * TM + i];
}
for (uint i = 0; i < TN; ++i) {
regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
}
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[resIdxM * TN + resIdxN] +=
regM[resIdxM] * regN[resIdxN];
}
}
}
__syncthreads();
}

// write out the results
for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
// load C vector into registers
float4 tmp = reinterpret_cast<float4 *>(
&C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0];
// perform GEMM update in reg
tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x;
tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y;
tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z;
tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w;
// write back
reinterpret_cast<float4 *>(
&C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] =
tmp;
}
}
}
54 changes: 54 additions & 0 deletions src/runner.cu
@@ -262,6 +262,54 @@ void runSgemmResolveBankConflicts(int M, int N, int K, float alpha, float *A,
}
}

void runSgemmResolveBankExtraCol(int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C) {
const uint BK = 8;
const uint TM = 8;
const uint TN = 8;
if (M >= 128 and N >= 128) {
const uint BM = 128;
const uint BN = 128;
dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
dim3 blockDim((BM * BN) / (TM * TN));
sgemmResolveBankExtraCol<BM, BN, BK, TM, TN>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
} else {
// this is a hacky solution to the underlying problem
// of not having proper bounds checking in the kernel
const uint BM = 64;
const uint BN = 64;
dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
dim3 blockDim((BM * BN) / (TM * TN));
sgemmResolveBankExtraCol<BM, BN, BK, TM, TN>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}
}

void runSgemmGeneralize(int M, int N, int K, float alpha, float *A, float *B,
float beta, float *C) {
const uint BK = 16;
const uint TM = 8;
const uint TN = 8;
if (M >= 128 and N >= 128) {
const uint BM = 128;
const uint BN = 128;
dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
dim3 blockDim((BM * BN) / (TM * TN));
sgemmGeneralize<BM, BN, BK, TM, TN>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
} else {
// this is a hacky solution to the underlying problem
// of not having proper bounds checking in the kernel
const uint BM = 64;
const uint BN = 64;
dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
dim3 blockDim((BM * BN) / (TM * TN));
sgemmGeneralize<BM, BN, BK, TM, TN>
<<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}
}

void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
float *B, float beta, float *C, cublasHandle_t handle) {
switch (kernel_num) {
@@ -289,6 +337,12 @@ void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
case 7:
runSgemmResolveBankConflicts(M, N, K, alpha, A, B, beta, C);
break;
case 8:
runSgemmResolveBankExtraCol(M, N, K, alpha, A, B, beta, C);
break;
case 9:
runSgemmGeneralize(M, N, K, alpha, A, B, beta, C);
break;
default:
throw std::invalid_argument("Unknown kernel number");
}
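
A minimal host-side sketch for exercising the new kernel (a hypothetical driver, not part of this commit; it assumes the runner above is linked in, with the cuBLAS-style semantics C = alpha*A*B + beta*C):

#include <cuda_runtime.h>
#include <vector>

// declared in src/runner.cu
void runSgemmGeneralize(int M, int N, int K, float alpha, float *A, float *B,
                        float beta, float *C);

int main() {
  const int M = 4096, N = 4096, K = 4096;
  std::vector<float> hA(M * K, 1.0f), hB(K * N, 1.0f), hC(M * N, 0.0f);
  float *dA, *dB, *dC;
  cudaMalloc(&dA, sizeof(float) * M * K);
  cudaMalloc(&dB, sizeof(float) * K * N);
  cudaMalloc(&dC, sizeof(float) * M * N);
  cudaMemcpy(dA, hA.data(), sizeof(float) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), sizeof(float) * K * N, cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC.data(), sizeof(float) * M * N, cudaMemcpyHostToDevice);

  runSgemmGeneralize(M, N, K, 1.0f, dA, dB, 0.0f, dC);
  cudaDeviceSynchronize();

  cudaMemcpy(hC.data(), dC, sizeof(float) * M * N, cudaMemcpyDeviceToHost);
  // with all-ones inputs, every element of C should equal (float)K
  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  return 0;
}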