Migrate USE_MAGMA config macro to ATen (pytorch#66390)
Summary: Pull Request resolved: pytorch#66390

Test Plan: Imported from OSS

Reviewed By: malfet, bdhirsh

Differential Revision: D31547712

Pulled By: ngimel

fbshipit-source-id: 1b2ebc0d5b5d2199029274eabdd014f343cfbdd3
peterbell10 authored and facebook-github-bot committed Oct 14, 2021
1 parent e75de4f commit 30d9fd9
Showing 9 changed files with 45 additions and 57 deletions.
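In short: the USE_MAGMA define, previously injected through THC's configure-generated THCGeneral.h, is replaced by a function-style AT_MAGMA_ENABLED() macro generated into ATen's CUDAConfig.h. Distilled from the hunks below, the call-site pattern changes like this:

// Before this commit: the flag came from THC's generated header.
#include <THC/THC.h>  // pulled in THCGeneral.h, which carried USE_MAGMA
#ifdef USE_MAGMA
  // MAGMA-backed implementation
#endif

// After this commit: the flag is a function-style macro in ATen.
#include <ATen/cuda/CUDAConfig.h>
#if AT_MAGMA_ENABLED()
  // MAGMA-backed implementation
#endif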
11 changes: 1 addition & 10 deletions BUILD.bazel
@@ -525,6 +525,7 @@ header_template_rule(
substitutions = {
"@AT_CUDNN_ENABLED@": "1",
"@AT_ROCM_ENABLED@": "0",
"@AT_MAGMA_ENABLED@": "0",
"@NVCC_FLAGS_EXTRA@": "",
},
)
@@ -537,15 +538,6 @@ header_template_rule(
},
)

header_template_rule(
name = "aten_src_THC_THCGeneral",
src = "aten/src/THC/THCGeneral.h.in",
out = "aten/src/THC/THCGeneral.h",
substitutions = {
"#cmakedefine USE_MAGMA": "",
},
)

cc_library(
name = "aten_headers",
hdrs = [
@@ -572,7 +564,6 @@ cc_library(
deps = [
":c10_headers",
":aten_src_TH_THGeneral",
":aten_src_THC_THCGeneral",
],
)

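With the substitution above, the Bazel build generates a CUDAConfig.h that hard-codes MAGMA off. Roughly, the configured header comes out as follows (a sketch assembled from the CUDAConfig.h.in template shown later in this diff; surrounding lines omitted):

#define AT_CUDNN_ENABLED() 1
#define AT_ROCM_ENABLED() 0
#define AT_MAGMA_ENABLED() 0  // Bazel builds take the non-MAGMA code paths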
1 change: 1 addition & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -30,6 +30,7 @@ set_bool(AT_BUILD_WITH_BLAS USE_BLAS)
set_bool(AT_BUILD_WITH_LAPACK USE_LAPACK)
set_bool(AT_BLAS_F2C BLAS_F2C)
set_bool(AT_BLAS_USE_CBLAS_DOT BLAS_USE_CBLAS_DOT)
set_bool(AT_MAGMA_ENABLED USE_MAGMA)
set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)

configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
6 changes: 6 additions & 0 deletions aten/src/ATen/cuda/CUDAConfig.h.in
@@ -9,5 +9,11 @@

#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
#define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
#define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@

// Needed for hipMAGMA to correctly identify implementation
#if (AT_ROCM_ENABLED() && AT_MAGMA_ENABLED())
#define HAVE_HIP 1
#endif

#define NVCC_FLAGS_EXTRA "@NVCC_FLAGS_EXTRA@"
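The commit message doesn't spell out why a function-style macro was chosen over a plain define, but one commonly cited advantage of this pattern (already used by AT_CUDNN_ENABLED above) is that misuse fails loudly. A sketch of the difference:

// If the config header was never included, #ifdef silently compiles
// the disabled branch:
#ifdef USE_MAGMA
  // MAGMA path, quietly skipped
#endif

// With a function-style macro, an undefined AT_MAGMA_ENABLED is
// replaced by 0 in the #if expression, leaving "0()" -- a hard
// preprocessing error that points straight at the missing include:
#if AT_MAGMA_ENABLED()
  // MAGMA path
#endif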
6 changes: 3 additions & 3 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -21,7 +21,7 @@
#include <ATen/cudnn/cudnn-wrapper.h>
#endif

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
#include <magma_v2.h>
#endif

@@ -118,7 +118,7 @@ bool CUDAHooks::hasCUDA() const {
}

bool CUDAHooks::hasMAGMA() const {
#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
return true;
#else
return false;
@@ -337,7 +337,7 @@ std::string CUDAHooks::showConfig() const {
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
#endif

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
#endif

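These hooks are what runtime capability checks bottom out in; this commit only moves the compile-time guard, so observable behavior is unchanged. A usage sketch, assuming ATen's at::hasMAGMA() convenience accessor (which routes through CUDAHooks::hasMAGMA()):

#include <ATen/Context.h>

void pick_backend() {
  if (at::hasMAGMA()) {
    // safe to dispatch to MAGMA-backed linear-algebra kernels
  } else {
    // fall back to cuSOLVER/cuBLAS (or CPU) implementations
  }
}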
56 changes: 27 additions & 29 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp
@@ -13,9 +13,7 @@
#include <ATen/native/cuda/BatchLinearAlgebraLib.h>
#include <ATen/native/cpu/zmath.h>

#include <THC/THC.h> // for USE_MAGMA

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
#include <magma_types.h>
#include <magma_v2.h>

@@ -28,7 +26,7 @@ const bool use_magma_ = false;
namespace at {
namespace native {

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
template<class scalar_t>
void magmaSolve(
magma_int_t n, magma_int_t nrhs, scalar_t* dA, magma_int_t ldda,
@@ -1233,7 +1231,7 @@ magma_trans_t to_magma(TransposeType trans) {
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
} // anonymous namespace
#endif // USE_MAGMA
#endif // AT_MAGMA_ENABLED()

#define ALLOCATE_ARRAY(name, type, size) \
auto storage_##name = pin_memory<type>(size); \
@@ -1243,7 +1241,7 @@ magma_trans_t to_magma(TransposeType trans) {

template <typename scalar_t>
static void apply_solve(Tensor& b, Tensor& A, Tensor& infos_out) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -1338,7 +1336,7 @@ For more information see MAGMA's documentation for GETRI and GETRF routines.
*/
template <typename scalar_t>
static void apply_batched_inverse(Tensor& self, Tensor& self_inv, Tensor& infos_lu, Tensor& infos_getri) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("inverse: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -1412,7 +1410,7 @@ AT_ERROR("inverse: MAGMA library not found in "

template <typename scalar_t>
static void apply_single_inverse(Tensor& self, Tensor& info_lu, Tensor& info_getri) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("inverse: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -1510,7 +1508,7 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& in

template <typename scalar_t>
static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("cholesky_solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -1606,7 +1604,7 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp

template <typename scalar_t>
static void apply_cholesky(const Tensor& self, bool upper, const Tensor& info) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.linalg.cholesky on a CUDA tensor requires compiling ",
@@ -1721,7 +1719,7 @@ For more information see MAGMA's documentation for POTRS routine.
*/
template <typename scalar_t>
static void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "cholesky_inverse: MAGMA library not found in compilation. Please rebuild with MAGMA.");
#else
// magmaCholeskyInverse (magma_dpotri_gpu) is slow because internally
@@ -1800,7 +1798,7 @@ REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
*/
template <typename scalar_t>
static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.lu on a CUDA tensor requires compiling ",
@@ -1861,7 +1859,7 @@ static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, con
*/
template <typename scalar_t>
static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.lu on a CUDA tensor requires compiling ",
@@ -1958,7 +1956,7 @@ REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);

template <typename scalar_t>
static void apply_triangular_solve_batched_magma(Tensor& A, Tensor& b, bool left, bool upper, TransposeType transpose, bool unitriangular) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("triangular_solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -2033,7 +2031,7 @@ void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, Transp
if (batchCount(A) <= 8 && A.size(-1) >= 64) {
triangular_solve_cublas(A, B, left, upper, transpose, unitriangular);
} else {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular);
#else
// cuBLAS batched is faster than MAGMA batched up until 512x512, after that MAGMA is faster
@@ -2042,7 +2040,7 @@
} else {
triangular_solve_batched_magma(A, B, left, upper, transpose, unitriangular);
}
#endif // USE_MAGMA
#endif // AT_MAGMA_ENABLED()
}
}

@@ -2082,7 +2080,7 @@ REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel);

template <typename scalar_t>
static void apply_geqrf(const Tensor& input, const Tensor& tau) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.geqrf on a CUDA tensor requires compiling ",
@@ -2160,7 +2158,7 @@ REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel);
template <typename scalar_t>
static void apply_qr(Tensor& Q, Tensor& R, int64_t q_size_minus_2, int64_t r_size_minus_1, int64_t n_columns,
bool compute_q) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("qr: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -2277,7 +2275,7 @@ std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, c10::stri

template <typename scalar_t>
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.linalg.eigh/eigvalsh on a CUDA tensor requires compiling ",
@@ -2442,7 +2440,7 @@ REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
template <typename scalar_t>
static void apply_eig(const Tensor& self, bool eigenvectors, Tensor& out_eigvals, Tensor& out_eigvecs,
int64_t *info_ptr) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA.");
#else
@@ -2532,7 +2530,7 @@ For more information see MAGMA's documentation for GEEV routine.
*/
template <typename scalar_t>
void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA.");
#else
@@ -2613,7 +2611,7 @@ REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
template<typename scalar_t>
static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
char jobchar, std::vector<int64_t>& infos) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
AT_ERROR("svd: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -2747,7 +2745,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
*/
template <typename scalar_t>
static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.lu_solve on a CUDA tensor requires compiling ",
@@ -2801,7 +2799,7 @@ static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const
*/
template <typename scalar_t>
static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.lu_solve on a CUDA tensor requires compiling ",
@@ -2911,7 +2909,7 @@ REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_dispatch);

template <typename scalar_t>
static void apply_gels(const Tensor& a, Tensor& b, Tensor& infos) {
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -3063,7 +3061,7 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul
"Please rebuild with cuSOLVER.");
#endif
} else { // m >= n
#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
// MAGMA is not available we can either use cuBLAS or cuSOLVER here
// the batched vs looped dispatch is implemented based on the following performance results
// https://github.com/pytorch/pytorch/pull/54725#issuecomment-832234456
@@ -3083,7 +3081,7 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul
// if both MAGMA and cuSOLVER are available this would call cuSOLVER
// MAGMA is called if cuSOLVER is not available
gels_looped(a, b, infos);
#endif // USE_MAGMA
#endif // AT_MAGMA_ENABLED()
}
}

@@ -3106,7 +3104,7 @@ std::tuple<Tensor, Tensor> legacy_lstsq_cuda(const Tensor &B, const Tensor &A) {
"X = torch.linalg.lstsq(A, B).solution"
);

#ifndef USE_MAGMA
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
@@ -3145,7 +3143,7 @@

TORCH_CHECK(info == 0, "MAGMA gels : Argument %d : illegal value", -info);
return std::tuple<Tensor, Tensor>(B_working, A_working);
#endif // USE_MAGMA
#endif // AT_MAGMA_ENABLED()
}

std::tuple<Tensor&, Tensor&> legacy_lstsq_out_cuda(
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/MiscUtils.h
@@ -2,18 +2,18 @@
#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <THC/THC.h> // for USE_MAGMA

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
#include <magma_types.h>
#include <magma_v2.h>
#endif

namespace at {
namespace native {

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()

// RAII for a MAGMA Queue
struct MAGMAQueue {
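The struct body is cut off by the diff, but the comment names the pattern: tie a magma_queue_t's lifetime to a C++ object so the queue is always destroyed. A sketch of such a wrapper, relying on the headers already included above (illustrative only; the actual member and helper names in MiscUtils.h may differ):

#if AT_MAGMA_ENABLED()
struct MagmaQueueSketch {
  explicit MagmaQueueSketch(magma_int_t device_id) {
    // Bind the MAGMA queue to the current CUDA stream and cuBLAS/cuSPARSE
    // handles so MAGMA work is ordered with the rest of the stream.
    magma_queue_create_from_cuda(
        device_id,
        at::cuda::getCurrentCUDAStream(),
        at::cuda::getCurrentCUDABlasHandle(),
        at::cuda::getCurrentCUDASparseHandle(),
        &magma_queue_);
  }
  magma_queue_t get_queue() const { return magma_queue_; }
  ~MagmaQueueSketch() { magma_queue_destroy(magma_queue_); }  // RAII cleanup
 private:
  magma_queue_t magma_queue_;
};
#endif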
5 changes: 1 addition & 4 deletions aten/src/THC/CMakeLists.txt
@@ -1,10 +1,7 @@
set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE}
"${CMAKE_CURRENT_BINARY_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}"
PARENT_SCOPE)

configure_file(THCGeneral.h.in "${CMAKE_CURRENT_BINARY_DIR}/THCGeneral.h")

set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
${CMAKE_CURRENT_SOURCE_DIR}/THCGeneral.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cpp
@@ -18,7 +15,7 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}

install(FILES
THC.h
${CMAKE_CURRENT_BINARY_DIR}/THCGeneral.h
THCGeneral.h
THCGeneral.hpp
THCSleep.h
THCStorage.h
6 changes: 0 additions & 6 deletions aten/src/THC/THCGeneral.h.in → aten/src/THC/THCGeneral.h
@@ -12,12 +12,6 @@

#include <cusparse.h>

#cmakedefine USE_MAGMA
/* Needed for hipMAGMA to correctly identify implementation */
#if defined(USE_MAGMA) && defined(USE_ROCM)
#define HAVE_HIP 1
#endif

#ifndef THAssert
#define THAssert(exp) \
do { \
5 changes: 3 additions & 2 deletions aten/src/THC/THCTensorMathMagma.cpp
@@ -1,13 +1,14 @@
#include <THC/THCGeneral.h>
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/CUDAConfig.h>

#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
#include <magma_v2.h>
#endif

namespace {
void _THCMagma_init() {
#ifdef USE_MAGMA
#if AT_MAGMA_ENABLED()
magma_init();
#endif
}
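The registration that actually invokes _THCMagma_init is cut off by the diff. As a general pattern, MAGMA's global magma_init() must run exactly once before the first MAGMA routine; a minimal sketch of such a guard (not necessarily how THC wires it up):

#include <mutex>
#include <magma_v2.h>

// Run MAGMA's one-time global initialization before the first MAGMA call.
void ensure_magma_init() {
  static std::once_flag flag;
  std::call_once(flag, []() { magma_init(); });
}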
