Skip to content

Commit

Permalink
Add compute capability 8.9 to default targets (vllm-project#829)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Aug 22, 2023
1 parent eedac9d commit a41c204
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

if CUDA_HOME is None:
raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.")
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")


def get_nvcc_cuda_version(cuda_dir: str) -> Version:
Expand Down Expand Up @@ -55,6 +55,14 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
compute_capabilities.remove(89)
compute_capabilities.add(80)
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
Expand All @@ -65,6 +73,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(89)
compute_capabilities.add(90)

# Add target compute capabilities to NVCC flags.
Expand Down

0 comments on commit a41c204

Please sign in to comment.