Skip to content

Commit

Permalink
Fix casting
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen committed Feb 24, 2024
1 parent c00bdad commit f9414e0
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions awq_ext/quantization_new/gemv/gemv_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,27 @@
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128


// Converts a half-precision value to single precision.
// Thin device-side wrapper around the __half2float intrinsic. Together with
// the float overload of to_float, it lets call sites use one spelling
// regardless of the argument's type.
static inline __device__ float to_float(half src)
{
    const float widened = __half2float(src);
    return widened;
}

// Identity overload: the argument is already a float, so it is returned as-is.
// Exists so to_float(x) resolves for float arguments as well as half ones.
static inline __device__ float to_float(float src) { return src; }

// Converts a single-precision value to half precision.
// Thin device-side wrapper around the __float2half intrinsic; the rounding
// behavior is whatever that intrinsic provides (see the CUDA Math API).
static inline __device__ half to_half(float src)
{
    const half narrowed = __float2half(src);
    return narrowed;
}

// Identity overload: the argument is already a half, so it is returned as-is.
// Exists so to_half(x) resolves for half arguments as well as float ones.
static inline __device__ half to_half(half src) { return src; }

// Reduce sum within the warp using the tree reduction algorithm.
template <int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4])
Expand All @@ -42,7 +63,7 @@ __device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = static_cast<float>(psum[i]);
fpsum[i] = to_float(psum[i]);
}

#pragma unroll
Expand Down Expand Up @@ -97,7 +118,7 @@ __global__ void gemv_kernel(

half psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<half>(0.f);
psum[i] = to_half(0.f);

extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
Expand Down Expand Up @@ -199,7 +220,7 @@ __global__ void gemv_kernel(
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half>(acc);
outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc);
}
}

Expand Down

0 comments on commit f9414e0

Please sign in to comment.