
Enabling PyTorch CUDA extensions for ROCm


Requirements:

  1. The extension's build script (e.g. setup.py) should use the torch.utils.cpp_extension.CUDAExtension class to build the extension,
        or
     the extension should be compiled just-in-time using torch.utils.cpp_extension.load() (see the JIT sketch after this list).
  2. PyTorch 1.8 or newer, or a build newer than the Dec 2 commit 5f62308.
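
For the just-in-time path, a minimal sketch follows. The module name and source file names are hypothetical placeholders; on a ROCm build of PyTorch, load() hipifies the sources automatically before compiling them:

from torch.utils.cpp_extension import load

# Compile and load the extension at runtime. On ROCm builds of PyTorch,
# the listed sources are hipified under the hood before compilation.
my_ext = load(
    name='my_ext',                               # hypothetical module name
    sources=['my_ext.cpp', 'my_ext_kernel.cu'],  # hypothetical sources
    verbose=True,
)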

How hipification works for PyTorch extensions:

CUDAExtension uses "hipification" to translate CUDA sources to HIP so that they can be used with ROCm.

  • Hipification is invoked automatically under the hood by the CUDAExtension class for ROCm builds.
  • Source files are hipified out-of-place, with the following general renaming rules for the hipified output files:
    • Paths containing 'cuda' are renamed to 'hip', e.g. csrc/include/cuda → csrc/include/hip
    • Filenames are modified by replacing cuda → hip and CUDA → HIP, e.g. normalize_kernels_cuda.h → normalize_kernels_hip.h
    • .cu files are renamed to .hip, e.g. normalize_kernels.cu → normalize_kernels.hip
    • Files that are modified but do not fall into the above categories are renamed with a _hip suffix, e.g. foo.c → foo_hip.c
    • Files that are not modified are not renamed or moved in any way.
  • The hipify script does text replacement of CUDA terms with their HIP equivalents according to the mappings specified in https://github.com/pytorch/pytorch/blob/master/torch/utils/hipify/cuda_to_hip_mappings.py (see the lookup sketch below).
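
To check how a particular CUDA identifier will be translated, you can query those mappings directly. A minimal sketch, assuming (as in current PyTorch) that CUDA_TO_HIP_MAPPINGS is a list of dicts whose values are tuples with the HIP replacement as the first element:

from torch.utils.hipify.cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS

def hip_equivalent(cuda_term):
    # Each mapping is a dict: CUDA identifier -> (HIP replacement, ...).
    for mapping in CUDA_TO_HIP_MAPPINGS:
        if cuda_term in mapping:
            return mapping[cuda_term][0]
    return None

print(hip_equivalent('cudaMalloc'))    # expected: hipMalloc
print(hip_equivalent('cudaStream_t'))  # expected: hipStream_t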

Steps to enable PyTorch CUDA extensions for ROCm:

  1. Use helper code such as the following to detect whether we are building for ROCm:

import torch

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
is_rocm_pytorch = False
if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
    from torch.utils.cpp_extension import ROCM_HOME
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
  2. Adjust the compiler flags in the 'nvcc' key of extra_compile_args to remove any nvcc-specific flags (i.e. flags not supported by hipcc) when building for ROCm, e.g.:

nvcc_flags = ['-O3'] + version_dependent_macros
if not is_rocm_pytorch:
    nvcc_flags.extend(['-lineinfo', '--use_fast_math'])
CUDAExtension(name='myExt',
              ...
              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                  'nvcc': nvcc_flags},
              ...

CUDAExtension handles everything else under the hood, such as hipifying the extension source files and adding the additional compiler flags that hipcc needs.
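
Putting the pieces together, a minimal setup.py for a ROCm-ready extension might look like the sketch below. The package name, source files, and version_dependent_macros are hypothetical placeholders:

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Hypothetical macros; substitute whatever your extension actually defines.
version_dependent_macros = ['-DVERSION_GE_1_5']

# Detect a ROCm build of PyTorch (see step 1).
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
is_rocm_pytorch = False
if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
    from torch.utils.cpp_extension import ROCM_HOME
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

# Only pass nvcc-specific flags when actually building with nvcc (see step 2).
nvcc_flags = ['-O3'] + version_dependent_macros
if not is_rocm_pytorch:
    nvcc_flags.extend(['-lineinfo', '--use_fast_math'])

setup(
    name='my_ext',  # hypothetical package name
    ext_modules=[
        CUDAExtension(
            name='my_ext',
            sources=['my_ext.cpp', 'my_ext_kernel.cu'],  # hypothetical sources
            extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                'nvcc': nvcc_flags},
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)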

  3. If you find that hipification did not replace all CUDA terms with their HIP equivalents, you may need to manually edit the source files to introduce some ifdef'd code for ROCm, e.g.:
#ifdef __HIP_PLATFORM_AMD__
    // ROCm build: select the standard rocBLAS GEMM algorithm.
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
#else
    // CUDA build: select the default cuBLAS GEMM algorithm.
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#endif
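
Note that hipcc defines __HIP_PLATFORM_AMD__ automatically when compiling for AMD GPUs, so no extra compiler flags are needed for this check (older ROCm toolchains used __HIP_PLATFORM_HCC__ instead).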