forked from siboehm/SGEMM_CUDA
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <cassert> | ||
#include <cstdio> | ||
#include <cstdlib> | ||
#include <cublas_v2.h> | ||
#include <cuda_runtime.h> | ||
|
||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) | ||
|
||
/* | ||
* @tparam BM The threadblock size for M dimension SMEM caching. | ||
* @tparam BN The threadblock size for N dimension SMEM caching. | ||
* @tparam BK The threadblock size for K dimension SMEM caching. | ||
* @tparam WM M dim of continuous tile computed by each warp | ||
* @tparam WN N dim of continuous tile computed by each warp | ||
* @tparam WMITER The number of subwarp tiling steps in M dimension. | ||
* @tparam WNITER The number of subwarp tiling steps in N dimension. | ||
* @tparam TM The per-thread tile size for M dimension. | ||
* @tparam TN The per-thread tile size for N dimension. | ||
*/ | ||
template <const int BM, const int BN, const int BK, const int WM, const int WN, | ||
const int WNITER, const int TM, const int TN, const int NUM_THREADS> | ||
__global__ void __launch_bounds__(NUM_THREADS) | ||
sgemmDoubleBuffering(int M, int N, int K, float alpha, float *A, float *B, | ||
float beta, float *C) { | ||
const uint cRow = blockIdx.y; | ||
const uint cCol = blockIdx.x; | ||
|
||
// Placement of the warp in the threadblock tile | ||
const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in | ||
const uint warpCol = warpIdx % (BN / WN); | ||
const uint warpRow = warpIdx / (BN / WN); | ||
|
||
// size of the warp subtile | ||
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); | ||
constexpr uint WSUBM = WM / WMITER; // 64/2=32 | ||
constexpr uint WSUBN = WN / WNITER; // 32/2=16 | ||
|
||
// Placement of the thread in the warp subtile | ||
const uint threadIdxInWarp = threadIdx.x % WARPSIZE; // [0, 31] | ||
const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4) | ||
const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4 | ||
|
||
// allocate space for the current blocktile in SMEM | ||
__shared__ float As[BM * BK]; | ||
__shared__ float Bs[BK * BN]; | ||
|
||
// Move blocktile to beginning of A's row and B's column | ||
A += cRow * BM * K; | ||
B += cCol * BN; | ||
// Move C_ptr to warp's output tile | ||
C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; | ||
|
||
// calculating the indices that this thread will load into SMEM | ||
// we'll load 128bit / 32bit = 4 elements per thread at each step | ||
const uint innerRowA = threadIdx.x / (BK / 4); | ||
const uint innerColA = threadIdx.x % (BK / 4); | ||
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; | ||
const uint innerRowB = threadIdx.x / (BN / 4); | ||
const uint innerColB = threadIdx.x % (BN / 4); | ||
constexpr uint rowStrideB = NUM_THREADS / (BN / 4); | ||
|
||
// allocate thread-local cache for results in registerfile | ||
float threadResults[WMITER * TM * WNITER * TN] = {0.0}; | ||
// we cache into registers on the warptile level | ||
float regM[WMITER * TM] = {0.0}; | ||
float regN[WNITER * TN] = {0.0}; | ||
|
||
// outer-most loop over block tiles | ||
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { | ||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { | ||
float4 tmp = reinterpret_cast<float4 *>( | ||
&A[(innerRowA + offset) * K + innerColA * 4])[0]; | ||
// transpose A while storing it | ||
As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x; | ||
As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y; | ||
As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z; | ||
As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w; | ||
} | ||
|
||
for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { | ||
reinterpret_cast<float4 *>( | ||
&Bs[(innerRowB + offset) * BN + innerColB * 4])[0] = | ||
reinterpret_cast<float4 *>( | ||
&B[(innerRowB + offset) * N + innerColB * 4])[0]; | ||
} | ||
__syncthreads(); | ||
|
||
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { | ||
// populate registers for whole warptile | ||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { | ||
for (uint i = 0; i < TM; ++i) { | ||
regM[wSubRowIdx * TM + i] = | ||
As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + | ||
threadRowInWarp * TM + i]; | ||
} | ||
} | ||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { | ||
for (uint i = 0; i < TN; ++i) { | ||
regN[wSubColIdx * TN + i] = | ||
Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + | ||
threadColInWarp * TN + i]; | ||
} | ||
} | ||
|
||
// execute warptile matmul | ||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { | ||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { | ||
// calculate per-thread results | ||
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { | ||
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { | ||
threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + | ||
(wSubColIdx * TN) + resIdxN] += | ||
regM[wSubRowIdx * TM + resIdxM] * | ||
regN[wSubColIdx * TN + resIdxN]; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
A += BK; // move BK columns to right | ||
B += BK * N; // move BK rows down | ||
__syncthreads(); | ||
} | ||
|
||
// write out the results | ||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { | ||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { | ||
// move C pointer to current warp subtile | ||
float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; | ||
for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { | ||
for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) { | ||
// load C vector into registers | ||
float4 tmp = reinterpret_cast<float4 *>( | ||
&C_interim[(threadRowInWarp * TM + resIdxM) * N + | ||
threadColInWarp * TN + resIdxN])[0]; | ||
// perform GEMM update in reg | ||
const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + | ||
wSubColIdx * TN + resIdxN; | ||
tmp.x = alpha * threadResults[i + 0] + beta * tmp.x; | ||
tmp.y = alpha * threadResults[i + 1] + beta * tmp.y; | ||
tmp.z = alpha * threadResults[i + 2] + beta * tmp.z; | ||
tmp.w = alpha * threadResults[i + 3] + beta * tmp.w; | ||
// write back | ||
reinterpret_cast<float4 *>( | ||
&C_interim[(threadRowInWarp * TM + resIdxM) * N + | ||
threadColInWarp * TN + resIdxN])[0] = tmp; | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters