Add minimum capability requirement for AWQ (vllm-project#1064)
WoosukKwon authored Sep 18, 2023
1 parent 268dc39 commit 2a5e962
Showing 5 changed files with 47 additions and 2 deletions.
8 changes: 8 additions & 0 deletions csrc/quantization/awq/dequantize.cuh
@@ -11,9 +11,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor

#pragma once

namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
uint4 result;

uint32_t* h = reinterpret_cast<uint32_t*>(&result);
@@ -75,5 +80,8 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

return result;
#endif
}

} // namespace awq
} // namespace vllm
18 changes: 16 additions & 2 deletions csrc/quantization/awq/gemm_kernels.cu
@@ -16,6 +16,9 @@ Adapted from https://github.com/mit-han-lab/llm-awq

#include <cuda_fp16.h>

namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
@@ -26,6 +29,9 @@ __pack_half2(const half x, const half y) {

__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
@@ -214,11 +220,15 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
}
}
#endif
}


__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
@@ -412,8 +422,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
}
}
#endif
}

} // namespace awq
} // namespace vllm

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
@@ -459,7 +473,7 @@ torch::Tensor awq_gemm(
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
@@ -470,7 +484,7 @@
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
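For orientation, a short Python sketch of inputs shaped the way the comments in gemm_kernels.cu above describe. The concrete sizes and the zeros layout are illustrative assumptions, not taken from this commit, and no extension binding is called here.

```python
import torch

# Illustrative sizes only; G is the quantization group size. The dispatch in
# awq_gemm above selects the m16n64k32 kernel when num_out_channels % 64 == 0.
M, IC, OC, G = 16, 4096, 4096, 128

# Shapes/dtypes as documented in the gemm_kernels.cu comments:
in_feats = torch.empty(M, IC, dtype=torch.float16, device="cuda")
# Packed 4-bit weights: IC x (OC // 8) int32 words, viewed as IC x OC uint4 values.
kernel = torch.empty(IC, OC // 8, dtype=torch.int32, device="cuda")
scaling_factors = torch.empty(IC // G, OC, dtype=torch.float16, device="cuda")
# Assumed: zero points packed like the weights, one row per group.
zeros = torch.empty(IC // G, OC // 8, dtype=torch.int32, device="cuda")
```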
8 changes: 8 additions & 0 deletions vllm/model_executor/model_loader.py
@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
quant_config = get_quant_config(model_config.quantization,
model_config.model,
model_config.download_dir)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
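For reference, a standalone sketch of the gate added above, assuming the AWQ minimum of 80; the helper name `check_min_capability` is illustrative and not part of the commit.

```python
import torch


def check_min_capability(min_capability: int = 80) -> None:
    """Mirror of the check in get_model(): encode (major, minor) as major*10+minor.

    For example, an A100 (8, 0) encodes to 80 and a T4 (7, 5) to 75, so only
    Ampere-or-newer GPUs satisfy the AWQ requirement of 80.
    """
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < min_capability:
        raise ValueError(
            f"Minimum capability: {min_capability}. "
            f"Current capability: {capability}.")
```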
5 changes: 5 additions & 0 deletions vllm/model_executor/quantization_utils/awq.py
@@ -40,6 +40,11 @@ def get_name(cls) -> str:
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]

@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Ampere or newer GPUs.
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return [
10 changes: 10 additions & 0 deletions vllm/model_executor/quantization_utils/base.py
@@ -15,6 +15,16 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError

@classmethod
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError

@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""
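To show how the new hook slots into the interface, a hypothetical subclass sketch. The class name `ExampleQuantConfig` and the returned values are illustrative, the base class in base.py is assumed to be named QuantizationConfig, and the real base class defines more methods than the three shown here.

```python
from typing import List

import torch

from vllm.model_executor.quantization_utils.base import QuantizationConfig


class ExampleQuantConfig(QuantizationConfig):
    """Hypothetical config illustrating where get_min_capability() fits."""

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Hypothetical kernel requiring Turing (SM 7.5) or newer.
        return 75

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]
```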
