Skip to content

Commit

Permalink
Fix torch cpp ext build when CPU wheel is installed but GPU card is present (#11608)

Browse files Browse the repository at this point in the history

* Fix torch cpp ext build when CPU wheel is installed but GPU card is present

Also there is a minor improvement for ATen operator that allows both
"::op" and "aten::op" name for operators

* Fix flake8 false positive
  • Loading branch information
Thiago Crepaldi committed May 25, 2022
1 parent 147a173 commit 4272304
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ exclude =
./orttraining,
# ignore server code for now
./server,
# ignore issues from different git branches
./.git,
ignore = W503, E203
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,16 @@ class ATenOperatorCache {
}

const ATenOperator& GetOperator(const std::string& op_name, const std::string& overload_name) {
auto key = std::make_pair(op_name, overload_name);
// PyTorch ONNX converter creates ATen operators with name without domain
std::string final_op_name = op_name;
auto pos = op_name.find("::");
if (pos == std::string::npos) {
final_op_name = std::string("aten::" + op_name);
}

auto key = std::make_pair(final_op_name, overload_name);
if (ops_.find(key) == ops_.end()) {
c10::OperatorName full_name(op_name, overload_name);
c10::OperatorName full_name(final_op_name, overload_name);
auto op = torch::jit::findOperatorFor(full_name);
TORCH_INTERNAL_ASSERT(op);
ATenOperator aten_op;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from onnxruntime.training import ortmodule

from glob import glob
from shutil import copyfile
import os
import subprocess
import sys
from glob import glob
from shutil import copyfile

import torch

from onnxruntime.training import ortmodule


def _list_extensions(path):
Expand All @@ -30,25 +32,26 @@ def _list_cuda_extensions():


def _install_extension(ext_name, ext_path, cwd):
ret_code = subprocess.call(f"{sys.executable} {ext_path} build", cwd=cwd, shell=True)
ret_code = subprocess.call((sys.executable, ext_path, "build"), cwd=cwd)
if ret_code != 0:
print(f'There was an error compiling "{ext_name}" PyTorch CPP extension')
print(f"There was an error compiling '{ext_name}' PyTorch CPP extension")
sys.exit(ret_code)


def build_torch_cpp_extensions():
"""Builds PyTorch CPP extensions and returns metadata"""

"""Builds PyTorch CPP extensions and returns metadata."""
# Run this from within onnxruntime package folder
is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
is_gpu_available = torch.cuda.is_available() and (
ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None
)
os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR)

# Extensions might leverage CUDA/ROCM versions internally
os.environ["ONNXRUNTIME_CUDA_VERSION"] = (
ortmodule.ONNXRUNTIME_CUDA_VERSION if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else ""
ortmodule.ONNXRUNTIME_CUDA_VERSION if ortmodule.ONNXRUNTIME_CUDA_VERSION is not None else ""
)
os.environ["ONNXRUNTIME_ROCM_VERSION"] = (
ortmodule.ONNXRUNTIME_ROCM_VERSION if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else ""
ortmodule.ONNXRUNTIME_ROCM_VERSION if ortmodule.ONNXRUNTIME_ROCM_VERSION is not None else ""
)

############################################################################
Expand Down

0 comments on commit 4272304

Please sign in to comment.