Skip to content

Commit

Permalink
Add a bit of asm
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Boehm committed Mar 12, 2023
1 parent 7c15d6d commit 49b63c1
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/kernels/10_kernel_warptiling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ const int WARPSIZE = 32; // warpSize is not constexpr
namespace wt {
template <const int BM, const int BN, const int BK, const int rowStrideA,
const int rowStrideB>
__device__ void loadFromGmem(int N, int K, float *A, float *B, float *As,
float *Bs, int innerRowA, int innerColA,
__device__ void loadFromGmem(int N, int K, const float *A, const float *B,
float *As, float *Bs, int innerRowA, int innerColA,
int innerRowB, int innerColB) {
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
float4 tmp = reinterpret_cast<float4 *>(
const float4 tmp = reinterpret_cast<const float4 *>(
&A[(innerRowA + offset) * K + innerColA * 4])[0];
// transpose A while storing it
// float4 tmp;
// asm("ld.global.nc.v4.f32 {%0, %1, %2, %3}, [%4];"
// : "=f"(tmp.x), "=f"(tmp.y), "=f"(tmp.z), "=f"(tmp.w)
// : "l"(&A[(innerRowA + offset) * K + innerColA * 4]));
As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x;
As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y;
As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z;
Expand All @@ -29,8 +32,14 @@ __device__ void loadFromGmem(int N, int K, float *A, float *B, float *As,
for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) {
reinterpret_cast<float4 *>(
&Bs[(innerRowB + offset) * BN + innerColB * 4])[0] =
reinterpret_cast<float4 *>(
reinterpret_cast<const float4 *>(
&B[(innerRowB + offset) * N + innerColB * 4])[0];
// asm("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];"
// : "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 0]),
// "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 1]),
// "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 2]),
// "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 3])
// : "l"(&B[(innerRowB + offset) * N + innerColB * 4]));
}
}

Expand Down

0 comments on commit 49b63c1

Please sign in to comment.