forked from siboehm/SGEMM_CUDA
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add failed attempt at resolving bank conflicts
- Loading branch information
Simon Boehm
committed
Jan 29, 2023
1 parent
33322ce
commit fb0a14d
Showing
6 changed files
with
318 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Small experiment: which of the 32 banks (bank = address % 32) do 32 threads
# hit when thread i addresses element i * ITEMS_PER_WARP of a 16-column tile?


def banks_naive(r, c):
    """Bank of element (r, c) when consecutive rows are 32 elements apart."""
    return (r * 32 + c) % 32


def banks_one_extra(r, c):
    """Bank of element (r, c) when consecutive rows are 33 elements apart."""
    return (r * 33 + c) % 32


ITEMS_PER_WARP = 8


def printBankConflicts(bank_fun):
    """Print the bank hit by each of 32 threads and a conflict summary.

    bank_fun: callable (row, col) -> bank index in 0..31.

    For every step, thread i is assigned flat offset i * ITEMS_PER_WARP + step
    inside a 16-column tile. The (thread, row, col, bank) tuples are printed,
    followed by the highest number of threads that share one bank and the
    number of distinct banks touched.
    """
    for c in range(1):
        banks = []
        for i in range(32):
            row = (i * ITEMS_PER_WARP) // 16
            col = (i * ITEMS_PER_WARP + c) % 16
            banks.append((i, row, col, bank_fun(row, col)))
        print(
            "Step",
            c,
            "\n",
            "\n".join("(" + ",".join(str(x) for x in entry) + ")" for entry in banks),
        )

        # Histogram of accesses per bank; banks never touched stay at 0.
        hits = {bank: 0 for bank in range(32)}
        for entry in banks:
            hits[entry[-1]] += 1

        worst = max(hits.values())
        banks_used = sum(1 for n in hits.values() if n > 0)

        print(f"Bank conflicts (Step {c}): {worst}, banks accessed: {banks_used}/32\n")


# The row stride (32 vs. 33) is already baked into the bank functions, so the
# function takes no second argument (passing one raised TypeError).
print("---NAIVE---")
printBankConflicts(banks_naive)

print("\n---EXTRA COL---")
printBankConflicts(banks_one_extra)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <cassert> | ||
#include <cstdio> | ||
#include <cstdlib> | ||
#include <cublas_v2.h> | ||
#include <cuda_runtime.h> | ||
|
||
// Integer ceiling division. The whole expansion is parenthesized so the macro
// can be embedded in a larger expression (e.g. `2 * CEIL_DIV(a, b)`).
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||
// Tiled SGEMM: C = alpha * (A * B) + beta * C, with A (M x K), B (K x N) and
// C (M x N) indexed row-major.
//
// Launch layout: 1D thread block with blockDim.x == (BM * BN) / (TM * TN);
// blockIdx.y selects the BM-row tile of C, blockIdx.x the BN-column tile.
// Each thread accumulates a TM x TN sub-tile of the block's BM x BN result.
//
// Shared memory: As holds the current A tile transposed (BK x BM); Bs holds
// the current B tile (BK x BN) with `extraCols` padding floats appended to
// every row (bank-conflict experiment).
//
// Preconditions (not checked at runtime): there is no bounds checking, so
// M, N and K must be multiples of BM, BN and BK. Global memory is accessed
// as float4, so A, B, C should be 16-byte aligned and N, K multiples of 4.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemmResolveBankExtraCol(int M, int N, int K, float alpha,
                                         float *A, float *B, float beta,
                                         float *C) {
  static_assert(BK % 4 == 0 && BN % 4 == 0 && TN % 4 == 0,
                "BK, BN and TN must be multiples of 4 (float4 accesses)");

  const uint cRow = blockIdx.y;
  const uint cCol = blockIdx.x;

  const uint totalResultsBlocktile = BM * BN;
  // A thread is responsible for calculating TM*TN elements in the blocktile
  const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);

  // ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
  assert(numThreadsBlocktile == blockDim.x);

  // BN/TN threads sit side by side across the width of the blocktile
  const int threadCol = threadIdx.x % (BN / TN);
  const int threadRow = threadIdx.x / (BN / TN);

  // allocate space for the current blocktile in smem
  __shared__ float As[BM * BK];
  // Every row of Bs is padded by extraCols floats. The padded row length is
  // no longer a multiple of 4 floats, so Bs is filled with scalar stores.
  const int extraCols = 5;
  const int strideBs = BN + extraCols;
  __shared__ float Bs[BK * strideBs];

  // Move blocktile to beginning of A's row and B's column
  A += cRow * BM * K;
  B += cCol * BN;
  C += cRow * BM * N + cCol * BN;

  // calculating the indices that this thread will load into SMEM
  // we'll load 128bit / 32bit = 4 elements per thread at each step
  const uint innerRowA = threadIdx.x / (BK / 4);
  const uint innerColA = threadIdx.x % (BK / 4);
  // number of rows of the A tile that the whole block loads in a single step
  const uint rowStrideA = (numThreadsBlocktile * 4) / BK;
  const uint innerRowB = threadIdx.x / (BN / 4);
  const uint innerColB = threadIdx.x % (BN / 4);
  // number of rows of the B tile that the whole block loads in a single step
  const uint rowStrideB = numThreadsBlocktile / (BN / 4);
  static_assert(rowStrideA > 0 && rowStrideB > 0,
                "too few threads to load a full tile row per step");

  // allocate thread-local cache for results in registerfile
  float threadResults[TM * TN] = {0.0f};
  float regM[TM] = {0.0f};
  float regN[TN] = {0.0f};

  // outer-most loop over block tiles
  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
    // Populate the SMEM caches. The tiles are filled in steps of
    // rowStrideA / rowStrideB rows so that every row is written even when one
    // step of the block does not cover the whole tile; the row guard keeps
    // surplus threads from indexing past the tile.
    // A is transposed while loading it.
    for (uint offset = 0; offset < BM; offset += rowStrideA) {
      const uint row = innerRowA + offset;
      if (row < BM) {
        const float4 tmp =
            reinterpret_cast<float4 *>(&A[row * K + innerColA * 4])[0];
        As[(innerColA * 4 + 0) * BM + row] = tmp.x;
        As[(innerColA * 4 + 1) * BM + row] = tmp.y;
        As[(innerColA * 4 + 2) * BM + row] = tmp.z;
        As[(innerColA * 4 + 3) * BM + row] = tmp.w;
      }
    }

    for (uint offset = 0; offset < BK; offset += rowStrideB) {
      const uint row = innerRowB + offset;
      if (row < BK) {
        const float4 tmp =
            reinterpret_cast<float4 *>(&B[row * N + innerColB * 4])[0];
        Bs[row * strideBs + innerColB * 4 + 0] = tmp.x;
        Bs[row * strideBs + innerColB * 4 + 1] = tmp.y;
        Bs[row * strideBs + innerColB * 4 + 2] = tmp.z;
        Bs[row * strideBs + innerColB * 4 + 3] = tmp.w;
      }
    }
    // all tile data must be in SMEM before any thread starts reading it
    __syncthreads();

    // advance blocktile
    A += BK;     // move BK columns to right
    B += BK * N; // move BK rows down

    // calculate per-thread results
    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
      // block into registers
      for (uint i = 0; i < TM; ++i) {
        regM[i] = As[dotIdx * BM + threadRow * TM + i];
      }
      for (uint i = 0; i < TN; ++i) {
        regN[i] = Bs[dotIdx * strideBs + threadCol * TN + i];
      }
      for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
          threadResults[resIdxM * TN + resIdxN] +=
              regM[resIdxM] * regN[resIdxN];
        }
      }
    }
    // nobody may overwrite the tiles while another thread still reads them
    __syncthreads();
  }

  // write out the results, four columns (one float4) at a time
  for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
    for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
      // load C vector into registers
      float4 tmp = reinterpret_cast<float4 *>(
          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0];
      // perform GEMM update in reg
      tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x;
      tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y;
      tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z;
      tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w;
      // write back
      reinterpret_cast<float4 *>(
          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] =
          tmp;
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <cassert> | ||
#include <cstdio> | ||
#include <cstdlib> | ||
#include <cublas_v2.h> | ||
#include <cuda_runtime.h> | ||
|
||
// Integer ceiling division. The whole expansion is parenthesized so the macro
// can be embedded in a larger expression (e.g. `2 * CEIL_DIV(a, b)`).
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||
// Tiled SGEMM: C = alpha * (A * B) + beta * C, with A (M x K), B (K x N) and
// C (M x N) indexed row-major.
//
// Launch layout: 1D thread block with blockDim.x == (BM * BN) / (TM * TN);
// blockIdx.y selects the BM-row tile of C, blockIdx.x the BN-column tile.
// Each thread accumulates a TM x TN sub-tile of the block's BM x BN result.
//
// Shared memory: As holds the current A tile transposed (BK x BM); Bs holds
// the current B tile (BK x BN).
//
// Preconditions (not checked at runtime): there is no bounds checking, so
// M, N and K must be multiples of BM, BN and BK. Global memory is accessed
// as float4, so A, B, C should be 16-byte aligned and N, K multiples of 4.
template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemmGeneralize(int M, int N, int K, float alpha, float *A,
                                float *B, float beta, float *C) {
  static_assert(BK % 4 == 0 && BN % 4 == 0 && TN % 4 == 0,
                "BK, BN and TN must be multiples of 4 (float4 accesses)");

  const uint cRow = blockIdx.y;
  const uint cCol = blockIdx.x;

  const uint totalResultsBlocktile = BM * BN;
  // A thread is responsible for calculating TM*TN elements in the blocktile
  const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);

  // ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
  assert(numThreadsBlocktile == blockDim.x);

  // BN/TN threads sit side by side across the width of the blocktile
  const int threadCol = threadIdx.x % (BN / TN);
  const int threadRow = threadIdx.x / (BN / TN);

  // allocate space for the current blocktile in smem
  __shared__ float As[BM * BK];
  __shared__ float Bs[BK * BN];

  // Move blocktile to beginning of A's row and B's column
  A += cRow * BM * K;
  B += cCol * BN;
  C += cRow * BM * N + cCol * BN;

  // calculating the indices that this thread will load into SMEM
  // we'll load 128bit / 32bit = 4 elements per thread at each step
  const uint innerRowA = threadIdx.x / (BK / 4);
  const uint innerColA = threadIdx.x % (BK / 4);
  // number of rows of the A tile that the whole block loads in a single step
  const uint rowStrideA = (numThreadsBlocktile * 4) / BK;
  const uint innerRowB = threadIdx.x / (BN / 4);
  const uint innerColB = threadIdx.x % (BN / 4);
  // number of rows of the B tile that the whole block loads in a single step
  const uint rowStrideB = numThreadsBlocktile / (BN / 4);
  static_assert(rowStrideA > 0 && rowStrideB > 0,
                "too few threads to load a full tile row per step");

  // allocate thread-local cache for results in registerfile
  float threadResults[TM * TN] = {0.0f};
  float regM[TM] = {0.0f};
  float regN[TN] = {0.0f};

  // outer-most loop over block tiles
  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
    // Populate the SMEM caches. The tiles are filled in steps of
    // rowStrideA / rowStrideB rows so that every row is written even when one
    // step of the block does not cover the whole tile; the row guard keeps
    // surplus threads from indexing past the tile.
    // A is transposed while loading it.
    for (uint offset = 0; offset < BM; offset += rowStrideA) {
      const uint row = innerRowA + offset;
      if (row < BM) {
        const float4 tmp =
            reinterpret_cast<float4 *>(&A[row * K + innerColA * 4])[0];
        As[(innerColA * 4 + 0) * BM + row] = tmp.x;
        As[(innerColA * 4 + 1) * BM + row] = tmp.y;
        As[(innerColA * 4 + 2) * BM + row] = tmp.z;
        As[(innerColA * 4 + 3) * BM + row] = tmp.w;
      }
    }

    // B is copied as-is, one float4 per thread and step
    for (uint offset = 0; offset < BK; offset += rowStrideB) {
      const uint row = innerRowB + offset;
      if (row < BK) {
        reinterpret_cast<float4 *>(&Bs[row * BN + innerColB * 4])[0] =
            reinterpret_cast<float4 *>(&B[row * N + innerColB * 4])[0];
      }
    }
    // all tile data must be in SMEM before any thread starts reading it
    __syncthreads();

    // advance blocktile
    A += BK;     // move BK columns to right
    B += BK * N; // move BK rows down

    // calculate per-thread results
    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
      // block into registers
      for (uint i = 0; i < TM; ++i) {
        regM[i] = As[dotIdx * BM + threadRow * TM + i];
      }
      for (uint i = 0; i < TN; ++i) {
        regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
      }
      for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
          threadResults[resIdxM * TN + resIdxN] +=
              regM[resIdxM] * regN[resIdxN];
        }
      }
    }
    // nobody may overwrite the tiles while another thread still reads them
    __syncthreads();
  }

  // write out the results, four columns (one float4) at a time
  for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) {
    for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) {
      // load C vector into registers
      float4 tmp = reinterpret_cast<float4 *>(
          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0];
      // perform GEMM update in reg
      tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x;
      tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y;
      tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z;
      tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w;
      // write back
      reinterpret_cast<float4 *>(
          &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] =
          tmp;
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters