Skip to content

Commit

Permalink
Fix cc
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen committed Feb 24, 2024
1 parent 5e2913a commit 365f8b3
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def get_generator_flag():
return generator_flag


def get_compute_capabilities():
def get_compute_capabilities(
compute_capabilities={75, 80, 86, 89, 90}
):
capability_flags = []

if CUDA_VERSION:
Expand All @@ -103,7 +105,6 @@ def get_compute_capabilities():
)

# Figure out compute capability
compute_capabilities = {75, 80, 86, 89, 90}
for cap in compute_capabilities:
capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

Expand Down Expand Up @@ -175,10 +176,23 @@ def get_extra_link_args():
"awq_ext/vllm/moe_alig_block.cu",
"awq_ext/vllm/activation.cu",
"awq_ext/vllm/topk_softmax_kernels.cu",
],
extra_compile_args=extra_compile_args,
)
)

# only compatible with ampere
arch_flags = get_compute_capabilities({80, 86, 89, 90})
extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags)

extensions.append(
CUDAExtension(
"awq_v2_ext",
[
"awq_ext/quantization_new/gemv/gemv_cuda.cu",
"awq_ext/quantization_new/gemm/gemm_cuda.cu",
],
extra_compile_args=extra_compile_args,
extra_compile_args=extra_compile_args_v2,
)
)

Expand Down

0 comments on commit 365f8b3

Please sign in to comment.