Add minimum capability requirement for AWQ #1064

Merged · 2 commits · Sep 18, 2023
8 changes: 8 additions & 0 deletions csrc/quantization/awq/dequantize.cuh
@@ -11,9 +11,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor

#pragma once

namespace vllm {
namespace awq {
Comment on lines +14 to +15
Member: Why is this namespace needed?

WoosukKwon (Collaborator, Author) · Sep 17, 2023: They are optional, but they follow better coding convention. From the Google C++ Style Guide:

With few exceptions, place code in a namespace.

Namespaces prevent naming conflicts, so they're pretty useful for external code like the AWQ kernels.
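
For illustration only (not part of this PR), a minimal C++ sketch of that point, with hypothetical dequantize helpers standing in for the real kernels: wrapping vendored code in vllm::awq lets an identically named symbol exist elsewhere in the binary without a collision.

// Minimal sketch (hypothetical helpers, not the actual AWQ code): two functions
// with the same name coexist because each lives in its own namespace.
#include <cstdio>

namespace vllm {
namespace awq {
inline int dequantize(int x) { return x * 2; }  // stand-in for an AWQ helper
}  // namespace awq
}  // namespace vllm

namespace other_lib {
inline int dequantize(int x) { return x + 1; }  // unrelated helper, same name
}  // namespace other_lib

int main() {
  // Callers disambiguate with the fully qualified name.
  std::printf("%d %d\n", vllm::awq::dequantize(3), other_lib::dequantize(3));
  return 0;
}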


__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
uint4 result;

uint32_t* h = reinterpret_cast<uint32_t*>(&result);
@@ -75,5 +80,8 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

return result;
#endif
}

} // namespace awq
} // namespace vllm
18 changes: 16 additions & 2 deletions csrc/quantization/awq/gemm_kernels.cu
@@ -16,6 +16,9 @@ Adapted from https://github.com/mit-han-lab/llm-awq

#include <cuda_fp16.h>

namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
@@ -26,6 +29,9 @@ __pack_half2(const half x, const half y) {

__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
@@ -214,11 +220,15 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
}
}
#endif
}


__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
@@ -412,8 +422,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
}
}
#endif
}

} // namespace awq
} // namespace vllm

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
@@ -459,7 +473,7 @@ torch::Tensor awq_gemm(
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
@@ -470,7 +484,7 @@
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
8 changes: 8 additions & 0 deletions vllm/model_executor/model_loader.py
@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
quant_config = get_quant_config(model_config.quantization,
model_config.model,
model_config.download_dir)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
Comment on lines +71 to +78
Member: Why do we need the assert false in C++ if we have the check here?

zhuohan123 (Member) · Sep 17, 2023: Just saw the comments in the PR. Can we just change setup.py instead of the C++ files?

Collaborator Author: It's problematic when we want to build the wheel for all GPU architectures (e.g., for PyPI publication or building a Docker image). In such a case, we cannot selectively include the extension according to the architecture. Therefore, I believe this is an easier solution; in fact, we already use this kind of guard for the bfloat16 attention kernels, which do not support Turing and Volta GPUs.
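
For illustration only (not part of the diff), a minimal CUDA sketch of the guard pattern being described, using a hypothetical kernel rather than the AWQ ones: nvcc compiles the kernel body once per target architecture, so in a fat binary built for all architectures the pre-sm_80 variants collapse to an assert, while the Python-side capability check keeps them from ever being launched.

// Minimal sketch of the per-architecture compile-time guard (hypothetical kernel).
#include <cassert>
#include <cuda_runtime.h>

__global__ void hypothetical_kernel(float* out)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    // Compiled for targets older than Ampere: no real body, just a trap.
    assert(false);
#else
    // Real work is only compiled for sm_80 and newer.
    out[threadIdx.x] = 1.0f;
#endif
}

int main() {
    float* d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(float));
    hypothetical_kernel<<<1, 32>>>(d_out);
    cudaDeviceSynchronize();
    cudaFree(d_out);
    return 0;
}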

supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
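As a side note on the check above (illustrative only, not part of the diff): torch.cuda.get_device_capability() returns the same (major, minor) pair that the CUDA runtime reports, so the capability = major * 10 + minor encoding maps, e.g., an A100 (8, 0) to 80 and an RTX 3090 (8, 6) to 86, both of which pass the AWQ requirement of 80. A sketch of the equivalent check against the CUDA runtime API, assuming device 0:

// Minimal sketch: the same capability arithmetic via cudaGetDeviceProperties.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);
    int capability = prop.major * 10 + prop.minor;  // e.g. 8 * 10 + 6 = 86
    const int min_capability = 80;                  // AWQ requires Ampere or newer
    if (capability < min_capability) {
        std::printf("AWQ unsupported: capability %d < %d\n", capability, min_capability);
        return 1;
    }
    std::printf("AWQ supported: capability %d\n", capability);
    return 0;
}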
5 changes: 5 additions & 0 deletions vllm/model_executor/quantization_utils/awq.py
@@ -40,6 +40,11 @@ def get_name(cls) -> str:
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]

@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Ampere or newer GPUs.
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return [
10 changes: 10 additions & 0 deletions vllm/model_executor/quantization_utils/base.py
@@ -15,6 +15,16 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError

@classmethod
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method.

E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError

@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""