Skip to content

Commit

Permalink
use cuda libraries of nvidia in CINN (PaddlePaddle#61595)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix
  • Loading branch information
risemeup1 committed Feb 8, 2024
1 parent adccc98 commit 345335d
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions python/setup_cinn.py.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import errno
import os
import re
import sys
import shutil
import errno
import platform
import subprocess
from contextlib import contextmanager
from setuptools import setup

Expand Down Expand Up @@ -156,7 +158,7 @@ for lib in cinnlibs:
set_rpath(os.path.join(libs_path, libname) , '$ORIGIN/')
package_data['cinn.libs'].append(libname)

set_rpath('${CMAKE_BINARY_DIR}/python/cinn/core_api.so', '$ORIGIN/libs/')
set_rpath('${CMAKE_BINARY_DIR}/python/cinn/core_api.so', '$ORIGIN/../nvidia/cuda_runtime/lib:$ORIGIN/../nvidia/cuda_nvrtc/lib:$ORIGIN/../nvidia/cudnn/lib:$ORIGIN/../nvidia/nvtx/lib:$ORIGIN/../nvidia/cublas/lib:$ORIGIN/../nvidia/curand/lib:$ORIGIN/../nvidia/cusolver/lib:$ORIGIN/libs/')

def git_commit():
try:
Expand All @@ -177,6 +179,53 @@ packages = ["cinn",
"cinn.runtime"
]

install_requires=[]

if platform.system() == 'Linux' and platform.machine() == 'x86_64':
paddle_cuda_install_requirements = os.getenv(
"PADDLE_CUDA_INSTALL_REQUIREMENTS", None
)
if paddle_cuda_install_requirements is not None:
PADDLE_CUDA_INSTALL_REQUIREMENTS = {
"V11": (
"nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu11==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
"V12": (
"nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
}
try:
output = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
version_line = [line for line in output.split('\n') if 'release' in line][0]
version = version_line.split(' ')[-1].split(',')[0]
cuda_major_version = version.split('.')[0]
except Exception as e:
raise ValueError("CUDA not found")

install_requires.append(PADDLE_CUDA_INSTALL_REQUIREMENTS[cuda_major_version].split("|"))



with redirect_stdout():
setup(
name='${PACKAGE_NAME}',
Expand All @@ -187,5 +236,6 @@ with redirect_stdout():
url='https://github.com/PaddlePaddle/Paddle',
license='Apache Software License',
packages=packages,
install_requires=install_requires,
package_data=package_data
)

0 comments on commit 345335d

Please sign in to comment.