pin_memory malloc now uses existing context if available. (pytorch#22229)

Summary:
This is achieved by using `cuDevicePrimaryCtxGetState` to check whether a primary context exists on a device. The call is cheap, as this benchmark of a single call on CUDA 10.1 (Titan Xp, driver 415.27) shows:
```
---------------------------------------------------------------------
Benchmark                              Time           CPU Iterations
---------------------------------------------------------------------
BM_cuDevicePrimaryCtxGetState        301 ns        301 ns    2319746
```
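
For reference, a number like the one above can be reproduced with a small Google Benchmark harness along the following lines. This is only a sketch, not the harness actually used for the table; it assumes the CUDA driver library (`libcuda`) and Google Benchmark are installed and at least one GPU is present:

```
// Sketch: time a single cuDevicePrimaryCtxGetState call with Google Benchmark.
// Build roughly as: g++ -O2 bench.cpp -lbenchmark -lcuda -lpthread -o bench
#include <benchmark/benchmark.h>
#include <cuda.h>

static void BM_cuDevicePrimaryCtxGetState(benchmark::State& state) {
  CUdevice dev;
  cuInit(0);             // initialize the driver API once
  cuDeviceGet(&dev, 0);  // query device 0
  unsigned int ctx_flags;
  int ctx_is_active;
  for (auto _ : state) {
    cuDevicePrimaryCtxGetState(dev, &ctx_flags, &ctx_is_active);
    benchmark::DoNotOptimize(ctx_is_active);
  }
}
BENCHMARK(BM_cuDevicePrimaryCtxGetState);
BENCHMARK_MAIN();
```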

Commits:

1. Add `CUDAHooks::getDeviceWithPrimaryContext`, which returns a device index that has a primary context (if one exists).
    Link `c10/cuda` against `libcuda` for the driver API calls.
2. Use `getDeviceWithPrimaryContext` in `pin_memory` to allocate on a device that already has a primary context (see the sketch after this list).
    Fix the `OptionalDeviceGuard` docs.
3. Refactor `test_cuda_primary_ctx.py` to support multiple tests.
    Add a test for this change there.
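
Taken together, the allocation-side logic added below (in `THCCachingHostAllocator.cpp`, backed by the new `CUDAHooks` hook) boils down to the following condensed sketch. Names are simplified, and the NVRTC/driver stub, error checking, and `OptionalDeviceGuard` are replaced with plain runtime-API calls, so treat it as an illustration rather than the actual implementation:

```
#include <cstddef>
#include <cuda.h>
#include <cuda_runtime_api.h>

// Does this device ordinal already have an active primary context?
static bool has_primary_ctx(int ordinal) {
  CUdevice dev;
  unsigned int flags;
  int active = 0;
  if (cuDeviceGet(&dev, ordinal) != CUDA_SUCCESS) return false;
  if (cuDevicePrimaryCtxGetState(dev, &flags, &active) != CUDA_SUCCESS) return false;
  return active == 1;
}

// Prefer the current device's primary context, then any other device's;
// return -1 if no device has one yet.
static int device_with_primary_ctx(int current, int count) {
  if (current >= 0 && has_primary_ctx(current)) return current;
  for (int i = 0; i < count; ++i) {
    if (i != current && has_primary_ctx(i)) return i;
  }
  return -1;
}

// Pinned memory is usable from every device under unified addressing, so
// switch to a device that already has a context before cudaHostAlloc and
// avoid spinning up a brand-new context just to pin memory.
static void* pin_memory_sketch(size_t size) {
  int current = 0, count = 0;
  cudaGetDevice(&current);
  cudaGetDeviceCount(&count);
  int target = device_with_primary_ctx(current, count);
  if (target >= 0) cudaSetDevice(target);
  void* ptr = nullptr;
  cudaHostAlloc(&ptr, size, cudaHostAllocDefault);
  if (target >= 0) cudaSetDevice(current);  // restore the original device
  return ptr;
}
```

In the real change the device switch goes through `at::OptionalDeviceGuard`, so the original device is restored even if the allocation fails, and the primary-context probe is routed through the NVRTC/driver stub via `CUDAHooks`.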

Fixes pytorch#21081.
Pull Request resolved: pytorch#22229

Differential Revision: D16170194

Pulled By: zou3519

fbshipit-source-id: 485a45f211b7844c9e69c63f3b3b75194a796c5d
ssnl authored and facebook-github-bot committed Jul 16, 2019
1 parent 054c7eb commit 8482efb
Showing 12 changed files with 208 additions and 83 deletions.
10 changes: 10 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -5,6 +5,7 @@
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/detail/CUDAHooksInterface.h>
@@ -114,6 +115,15 @@ int64_t CUDAHooks::current_device() const {
return -1;
}

bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
int ctx_is_active;
AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
return ctx_is_active == 1;
}

Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
return at::cuda::getPinnedMemoryAllocator();
}
1 change: 1 addition & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -18,6 +18,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCuDNN() const override;
const at::cuda::NVRTC& nvrtc() const override;
int64_t current_device() const override;
bool hasPrimaryContext(int64_t device_index) const override;
Allocator* getPinnedMemoryAllocator() const override;
bool compiledWithCuDNN() const override;
bool compiledWithMIOpen() const override;
3 changes: 2 additions & 1 deletion aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
@@ -57,7 +57,8 @@ namespace at { namespace cuda {
//
// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both
// NVRTC and driver APIs. While the former is not yet supported for HIP, the
// latter is supported and needed.
// latter is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext()
// used by tensor.pin_memory()).
//
// The macro below strips out certain unsupported operations on HIP from the full
// list above.
28 changes: 16 additions & 12 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -60,15 +60,15 @@ struct CAFFE2_API CUDAHooksInterface {

// Initialize THCState and, transitively, the CUDA state
virtual std::unique_ptr<THCState, void (*)(THCState*)> initCUDA() const {
AT_ERROR("Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
}

virtual Generator* getDefaultCUDAGenerator(DeviceIndex device_index = -1) const {
AT_ERROR("Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
}

virtual Device getDeviceFromPtr(void* data) const {
AT_ERROR("Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
}

virtual bool hasCUDA() const {
@@ -84,15 +84,19 @@
}

virtual const at::cuda::NVRTC& nvrtc() const {
AT_ERROR("NVRTC requires CUDA. ", CUDA_HELP);
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}

virtual int64_t current_device() const {
return -1;
}

virtual bool hasPrimaryContext(int64_t device_index) const {
TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
}

virtual Allocator* getPinnedMemoryAllocator() const {
AT_ERROR("Pinned memory requires CUDA. ", CUDA_HELP);
TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
}

virtual bool compiledWithCuDNN() const {
@@ -112,32 +116,32 @@
}

virtual long versionCuDNN() const {
AT_ERROR("Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}

virtual std::string showConfig() const {
AT_ERROR("Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
}

virtual double batchnormMinEpsilonCuDNN() const {
AT_ERROR(
TORCH_CHECK(false,
"Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
}

virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

virtual void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

virtual int64_t cuFFTGetPlanCacheSize(int64_t device_index) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

virtual void cuFFTClearPlanCache(int64_t device_index) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}

virtual int getNumGPUs() const {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Memory.cpp
@@ -1,10 +1,10 @@
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/Storage.h>
#include <ATen/TensorUtils.h>

namespace at {
namespace native {
31 changes: 31 additions & 0 deletions aten/src/THC/THCCachingHostAllocator.cpp
@@ -1,4 +1,6 @@
#include <THC/THCCachingHostAllocator.h>
#include <ATen/DeviceGuard.h>
#include <ATen/detail/CUDAHooksInterface.h>


#include <cuda_runtime_api.h>
@@ -40,6 +42,24 @@ static bool BlockComparator(const BlockSize& a, const BlockSize& b)
return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
}

static int64_t inline get_device_index_with_primary_context() {
const auto& cuda_hooks = at::detail::getCUDAHooks();
// check current device first
int64_t current_device_index = cuda_hooks.current_device();
if (current_device_index >= 0) {
if (cuda_hooks.hasPrimaryContext(current_device_index)) {
return current_device_index;
}
}
for (int64_t device_index = 0; device_index < cuda_hooks.getNumGPUs(); device_index++) {
if (device_index == current_device_index) continue;
if (cuda_hooks.hasPrimaryContext(device_index)) {
return device_index;
}
}
return -1;
}

struct HostAllocator
{
typedef bool (*Comparison)(const BlockSize&, const BlockSize&);
@@ -80,6 +100,17 @@ struct HostAllocator
return cudaSuccess;
}

// Pinned memory pointers allocated by any device can be directly used by any
// other device, regardless of the current device at the time of allocation,
// since we assume unified addressing.
// So we grab any existing primary context, if available.
// See pytorch/pytorch#21081.
at::OptionalDeviceGuard device_guard;
auto primary_ctx_device_index = get_device_index_with_primary_context();
if (primary_ctx_device_index >= 0) {
device_guard.reset_device(at::Device(at::DeviceType::CUDA, primary_ctx_device_index));
}

// note that cudaHostAlloc may not touch pointer if size is 0
*ptr = 0;

6 changes: 3 additions & 3 deletions c10/core/DeviceGuard.h
@@ -108,7 +108,7 @@ class DeviceGuard {
* setDevice(1);
* OptionalDeviceGuard g;
* setDevice(2);
* g.set_device(3); // initializes!
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
*
* On destruction, g will reset device to 2, rather than 1.
*
@@ -118,7 +118,7 @@
*/
class OptionalDeviceGuard {
public:
/// Create an uninitialized guard. Set the guard later using set_device.
/// Create an uninitialized guard. Set the guard later using reset_device.
explicit OptionalDeviceGuard() : guard_() {}

/// Initialize the guard, setting the current device to the passed Device.
@@ -159,7 +159,7 @@ class OptionalDeviceGuard {
}

/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device.
/// either from construction, or via reset_device.
optional<Device> current_device() const {
return guard_.current_device();
}
2 changes: 1 addition & 1 deletion c10/cuda/CUDAException.h
@@ -2,7 +2,7 @@

#include "c10/util/Exception.h"
#include "c10/macros/Macros.h"
#include "cuda.h"
#include <cuda.h>

// Note [CHECK macro]
// ~~~~~~~~~~~~~~~~~~
59 changes: 58 additions & 1 deletion test/common_utils.py
@@ -18,6 +18,7 @@
import random
import contextlib
import socket
import subprocess
import time
from collections import OrderedDict
from contextlib import contextmanager
@@ -46,18 +47,74 @@


parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--subprocess', action='store_true',
                    help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
TEST_IN_SUBPROCESS = args.subprocess
SEED = args.seed
if not expecttest.ACCEPT:
    expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)


def shell(command, cwd=None):
    sys.stdout.flush()
    sys.stderr.flush()
    # The following cool snippet is copied from Py3 core library subprocess.call,
    # with only these changes:
    # 1. `except KeyboardInterrupt` block added for SIGINT handling.
    # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
    #    `p.wait()` in a `finally` block for the code to be portable.
    #
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
    try:
        return p.wait()
    except KeyboardInterrupt:
        # Give `p` a chance to handle KeyboardInterrupt. Without this,
        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
        exit_status = p.wait(timeout=5)
        if exit_status is not None:
            return exit_status
        else:
            p.kill()
            raise
    except: # noqa E722, copied from python core library
        p.kill()
        raise
    finally:
        # Always call p.wait() to ensure exit
        p.wait()


def run_tests(argv=UNITTEST_ARGS):
    unittest.main(argv=argv)
    if TEST_IN_SUBPROCESS:
        suite = unittest.TestLoader().loadTestsFromModule(__main__)
        test_cases = []

        def add_to_test_cases(suite_or_case):
            if isinstance(suite_or_case, unittest.TestCase):
                test_cases.append(suite_or_case)
            else:
                for element in suite_or_case:
                    add_to_test_cases(element)

        add_to_test_cases(suite)
        failed_tests = []
        for case in test_cases:
            test_case_full_name = case.id().split('.', 1)[1]
            exitcode = shell([sys.executable] + argv + [test_case_full_name])
            if exitcode != 0:
                failed_tests.append(test_case_full_name)

        assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
            len(failed_tests), '\n\t'.join(failed_tests))
    else:
        unittest.main(argv=argv)

PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)
42 changes: 8 additions & 34 deletions test/run_test.py
@@ -14,7 +14,7 @@
import torch
import torch._six
from torch.utils import cpp_extension
from common_utils import TEST_WITH_ROCM
from common_utils import TEST_WITH_ROCM, shell
import torch.distributed as dist

TESTS = [
@@ -98,49 +98,22 @@ def print_to_stderr(message):
    print(message, file=sys.stderr)


def shell(command, cwd=None):
    sys.stdout.flush()
    sys.stderr.flush()
    # The following cool snippet is copied from Py3 core library subprocess.call,
    # with only these changes:
    # 1. `except KeyboardInterrupt` block added for SIGINT handling.
    # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
    #    `p.wait()` in a `finally` block for the code to be portable.
    #
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
    try:
        return p.wait()
    except KeyboardInterrupt:
        # Give `p` a chance to handle KeyboardInterrupt. Without this,
        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
        exit_status = p.wait(timeout=5)
        if exit_status is not None:
            return exit_status
        else:
            p.kill()
            raise
    except: # noqa E722, copied from python core library
        p.kill()
        raise
    finally:
        # Always call p.wait() to ensure exit
        p.wait()


def run_test(executable, test_module, test_directory, options):
def run_test(executable, test_module, test_directory, options, *extra_unittest_args):
    unittest_args = options.additional_unittest_args
    if options.verbose:
        unittest_args.append('--verbose')
    # Can't call `python -m unittest test_*` here because it doesn't run code
    # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
    argv = [test_module + '.py'] + unittest_args
    argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)

    command = executable + argv
    return shell(command, test_directory)


def test_cuda_primary_ctx(executable, test_module, test_directory, options):
    return run_test(executable, test_module, test_directory, options, '--subprocess')


def test_cpp_extensions(executable, test_module, test_directory, options):
    try:
        cpp_extension.verify_ninja_availability()
@@ -226,6 +199,7 @@ def test_distributed(executable, test_module, test_directory, options):


CUSTOM_HANDLERS = {
    'cuda_primary_ctx': test_cuda_primary_ctx,
    'cpp_extensions': test_cpp_extensions,
    'distributed': test_distributed,
}