Commit d8aa68c: make sure that our error handling runs with the GIL enabled (pytorch#92848)

Fixes pytorch#92684

I checked the other use cases of this API and they never release the GIL.
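
Concretely, `py::call_guard<py::gil_scoped_release>()` releases the GIL around the entire bound call, so the HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS_PYBIND catch blocks inside `wrap_pybind_function` were translating C++ exceptions into Python exceptions without holding the GIL. A minimal sketch of the two shapes (illustrative only; `may_throw`, `unsafe_wrapper`, and `safe_wrapper` are invented names, not PyTorch code):

// Illustrative sketch, not PyTorch source: translating a C++ exception
// into a Python one touches CPython state, so it needs the GIL.
#include <pybind11/pybind11.h>
#include <stdexcept>
namespace py = pybind11;

void may_throw() { throw std::runtime_error("boom"); }

// Buggy shape: registered with py::call_guard<py::gil_scoped_release>(),
// so the catch block below runs with the GIL released and
// PyErr_SetString is called on a GIL-free thread.
void unsafe_wrapper() {
  try {
    may_throw();
  } catch (const std::exception& e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());  // requires the GIL
    throw py::error_already_set();
  }
}

// Fixed shape (what this commit does): release the GIL only around the
// C++ work. Stack unwinding destroys `no_gil` before the handler runs,
// so the exception is translated with the GIL re-acquired.
void safe_wrapper() {
  try {
    py::gil_scoped_release no_gil;
    may_throw();
  } catch (const std::exception& e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    throw py::error_already_set();
  }
}

PYBIND11_MODULE(gil_demo, m) {
  m.def("unsafe", &unsafe_wrapper, py::call_guard<py::gil_scoped_release>());
  m.def("safe", &safe_wrapper);
}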

Pull Request resolved: pytorch#92848
Approved by: https://github.com/ngimel
albanD authored and pytorchmergebot committed Jan 24, 2023
1 parent abe6488 commit d8aa68c
Showing 3 changed files with 67 additions and 21 deletions.
test/test_cuda.py (33 additions, 0 deletions)

@@ -16,6 +16,7 @@
 import threading
 import unittest
 import warnings
+import subprocess
 from random import randint
 
 import torch
@@ -3334,6 +3335,38 @@ def test_graph_capture_simple(self):

         self.assertTrue(b.sum().item() == 11000.)
 
+    @unittest.skipIf((not TEST_CUDA) or
+                     TEST_WITH_ROCM or
+                     int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    def test_graph_error(self):
+        # We need to run this test in a separate process as the error we
+        # trigger puts the cuda context in a bad state
+        script = """
+import torch
+g = torch.cuda.CUDAGraph()
+try:
+    g.capture_begin()
+except RuntimeError as e:
+    if "CUDA graphs must be captured on a non-default stream." in str(e):
+        exit(0)
+    else:
+        exit(1)
+exit(2)
+"""
+        try:
+            a = subprocess.check_output(
+                [sys.executable, '-c', script],
+                stderr=subprocess.STDOUT,
+                # On Windows, opening the subprocess with the default CWD makes `import torch`
+                # fail, so just set CWD to this script's directory
+                cwd=os.path.dirname(os.path.realpath(__file__)),)
+        except subprocess.CalledProcessError as e:
+            if e.returncode == 1:
+                self.assertTrue(False, "Error raised by starting capture without a stream is not the expected one")
+            elif e.returncode == 2:
+                self.assertTrue(False, "Error raised by starting capture without a stream was not caught")
+
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
torch/csrc/Exceptions.h (22 additions, 4 deletions)

@@ -379,14 +379,23 @@
 template <typename Func, size_t i>
 using Arg = typename invoke_traits<Func>::template arg<i>::type;
 
 template <typename Func, size_t... Is>
-auto wrap_pybind_function_impl_(Func&& f, std::index_sequence<Is...>) {
+auto wrap_pybind_function_impl_(
+    Func&& f,
+    std::index_sequence<Is...>,
+    bool release_gil) {
   using result_type = typename invoke_traits<Func>::result_type;
+  namespace py = pybind11;
 
   // f=f is needed to handle function references on older compilers
-  return [f = std::forward<Func>(f)](Arg<Func, Is>... args) -> result_type {
+  return [f = std::forward<Func>(f),
+          release_gil](Arg<Func, Is>... args) -> result_type {
     HANDLE_TH_ERRORS
-    return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
+    if (release_gil) {
+      py::gil_scoped_release no_gil;
+      return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
+    } else {
+      return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
+    }
     END_HANDLE_TH_ERRORS_PYBIND
   };
 }
@@ -398,7 +407,16 @@ template <typename Func>
 auto wrap_pybind_function(Func&& f) {
   using traits = invoke_traits<Func>;
   return torch::detail::wrap_pybind_function_impl_(
-      std::forward<Func>(f), std::make_index_sequence<traits::arity>{});
+      std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, false);
 }
 
+// Wrap a function with TH error and warning handling, releasing the GIL
+// around the wrapped call. Returns a function object suitable for
+// registering with pybind11.
+template <typename Func>
+auto wrap_pybind_function_no_gil(Func&& f) {
+  using traits = invoke_traits<Func>;
+  return torch::detail::wrap_pybind_function_impl_(
+      std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, true);
+}
 
 } // namespace torch
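
A hypothetical usage sketch of the new helper (`Demo` and `init_demo` are invented for illustration; the real call sites are the Graph.cpp bindings below):

// Hypothetical usage sketch; `Demo` and `init_demo` are made-up names.
#include <torch/csrc/Exceptions.h>
#include <pybind11/pybind11.h>

struct Demo {
  void work() {
    // long-running C++ work; safe to run without the GIL
  }
};

void init_demo(pybind11::module_& m) {
  pybind11::class_<Demo>(m, "Demo")
      .def(pybind11::init<>())
      // Releases the GIL around Demo::work itself, but any error thrown
      // by it is translated to a Python exception with the GIL held,
      // so no py::call_guard is needed.
      .def("work", torch::wrap_pybind_function_no_gil(&Demo::work));
}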
torch/csrc/cuda/Graph.cpp (12 additions, 17 deletions)

@@ -30,37 +30,32 @@ void THCPGraph_init(PyObject* module) {
       // docs aren't clear. But it works.
       .def(
           "capture_begin",
-          torch::wrap_pybind_function(&at::cuda::CUDAGraph::capture_begin),
-          py::call_guard<py::gil_scoped_release>(),
+          torch::wrap_pybind_function_no_gil(
+              &at::cuda::CUDAGraph::capture_begin),
           py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
       .def(
           "capture_end",
-          torch::wrap_pybind_function(&at::cuda::CUDAGraph::capture_end),
-          py::call_guard<py::gil_scoped_release>())
+          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
       .def(
           "replay",
-          torch::wrap_pybind_function(&at::cuda::CUDAGraph::replay),
-          py::call_guard<py::gil_scoped_release>())
+          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
       .def(
           "reset",
-          torch::wrap_pybind_function(&at::cuda::CUDAGraph::reset),
-          py::call_guard<py::gil_scoped_release>())
+          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
       .def(
           "pool",
-          torch::wrap_pybind_function(&at::cuda::CUDAGraph::pool),
-          py::call_guard<py::gil_scoped_release>())
+          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
       .def(
           "debug_dump",
-          torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump),
-          py::call_guard<py::gil_scoped_release>())
+          torch::wrap_pybind_function_no_gil(
+              &::at::cuda::CUDAGraph::debug_dump))
       .def(
           "enable_debug_mode",
-          torch::wrap_pybind_function(
-              &::at::cuda::CUDAGraph::enable_debug_mode),
-          py::call_guard<py::gil_scoped_release>())
+          torch::wrap_pybind_function_no_gil(
+              &::at::cuda::CUDAGraph::enable_debug_mode))
       .def(
           "debug_dump",
-          torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump),
-          py::call_guard<py::gil_scoped_release>(),
+          torch::wrap_pybind_function_no_gil(
+              &::at::cuda::CUDAGraph::debug_dump),
           py::arg("debug_path"));
 }
