diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index 3abbe2f868990..608f71fe7edd1 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -114,6 +115,15 @@ int64_t CUDAHooks::current_device() const {
   return -1;
 }
 
+bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
+  TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
+              "hasPrimaryContext expects valid device index, but got device_index=", device_index);
+  unsigned int ctx_flags;
+  int ctx_is_active;
+  AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
+  return ctx_is_active == 1;
+}
+
 Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
   return at::cuda::getPinnedMemoryAllocator();
 }
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index cf2ece1b2d2ef..4d2d0c6812b72 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -18,6 +18,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCuDNN() const override;
   const at::cuda::NVRTC& nvrtc() const override;
   int64_t current_device() const override;
+  bool hasPrimaryContext(int64_t device_index) const override;
   Allocator* getPinnedMemoryAllocator() const override;
   bool compiledWithCuDNN() const override;
   bool compiledWithMIOpen() const override;
diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
index 118e40e672092..56a4b6893f15d 100644
--- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
+++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
@@ -57,7 +57,8 @@ namespace at { namespace cuda {
 //
 // ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both
 // NVRTC and driver APIs. While the former is not yet suppoted for HIP, the
-// later is supported and needed.
+// latter is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext()
+// used by tensor.pin_memory()).
 //
 // The macro below strips out certain unsupported operations on HIP from the full
 // list above.
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index af22c554433de..da24b5cc0cc4b 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -60,15 +60,15 @@ struct CAFFE2_API CUDAHooksInterface {
   // Initialize THCState and, transitively, the CUDA state
   virtual std::unique_ptr<THCState, void (*)(THCState*)> initCUDA() const {
-    AT_ERROR("Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual Generator* getDefaultCUDAGenerator(DeviceIndex device_index = -1) const {
-    AT_ERROR("Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual Device getDeviceFromPtr(void* data) const {
-    AT_ERROR("Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual bool hasCUDA() const {
@@ -84,15 +84,19 @@ struct CAFFE2_API CUDAHooksInterface {
   }
 
   virtual const at::cuda::NVRTC& nvrtc() const {
-    AT_ERROR("NVRTC requires CUDA. ", CUDA_HELP);
+    TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }
 
   virtual int64_t current_device() const {
     return -1;
   }
 
+  virtual bool hasPrimaryContext(int64_t device_index) const {
+    TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
+  }
+
   virtual Allocator* getPinnedMemoryAllocator() const {
-    AT_ERROR("Pinned memory requires CUDA. ", CUDA_HELP);
+    TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
   }
 
   virtual bool compiledWithCuDNN() const {
@@ -112,32 +116,32 @@ struct CAFFE2_API CUDAHooksInterface {
   }
 
   virtual long versionCuDNN() const {
-    AT_ERROR("Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual std::string showConfig() const {
-    AT_ERROR("Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual double batchnormMinEpsilonCuDNN() const {
-    AT_ERROR(
+    TORCH_CHECK(false,
         "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
-    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
-    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual int64_t cuFFTGetPlanCacheSize(int64_t device_index) const {
-    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual void cuFFTClearPlanCache(int64_t device_index) const {
-    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }
 
   virtual int getNumGPUs() const {
diff --git a/aten/src/ATen/native/Memory.cpp b/aten/src/ATen/native/Memory.cpp
index b2dd252df6716..aedae3942659e 100644
--- a/aten/src/ATen/native/Memory.cpp
+++ b/aten/src/ATen/native/Memory.cpp
@@ -1,10 +1,10 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
-#include
 
 namespace at {
 namespace native {
diff --git a/aten/src/THC/THCCachingHostAllocator.cpp b/aten/src/THC/THCCachingHostAllocator.cpp
index 6f0782103b55e..db3f511758398 100644
--- a/aten/src/THC/THCCachingHostAllocator.cpp
+++ b/aten/src/THC/THCCachingHostAllocator.cpp
@@ -1,4 +1,6 @@
 #include
+#include
+#include
 
 #include
 
@@ -40,6 +42,24 @@ static bool BlockComparator(const BlockSize& a, const BlockSize& b)
   return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
 }
 
+static int64_t inline get_device_index_with_primary_context() {
+  const auto& cuda_hooks = at::detail::getCUDAHooks();
+  // check current device first
+  int64_t current_device_index = cuda_hooks.current_device();
+  if (current_device_index >= 0) {
+    if (cuda_hooks.hasPrimaryContext(current_device_index)) {
+      return current_device_index;
+    }
+  }
+  for (int64_t device_index = 0; device_index < cuda_hooks.getNumGPUs(); device_index++) {
+    if (device_index == current_device_index) continue;
+    if (cuda_hooks.hasPrimaryContext(device_index)) {
+      return device_index;
+    }
+  }
+  return -1;
+}
+
 struct HostAllocator
 {
   typedef bool (*Comparison)(const BlockSize&, const BlockSize&);
@@ -80,6 +100,17 @@ struct HostAllocator
       return cudaSuccess;
     }
 
+    // Pinned memory pointers allocated by any device can be directly used by any
+    // other device, regardless of the current device at the time of allocation,
+    // since we assume unified addressing.
+    // So we grab any existing primary context, if available.
+    // See pytorch/pytorch#21081.
+    at::OptionalDeviceGuard device_guard;
+    auto primary_ctx_device_index = get_device_index_with_primary_context();
+    if (primary_ctx_device_index >= 0) {
+      device_guard.reset_device(at::Device(at::DeviceType::CUDA, primary_ctx_device_index));
+    }
+
     // note that cudaHostAlloc may not touch pointer if size is 0
     *ptr = 0;
diff --git a/c10/core/DeviceGuard.h b/c10/core/DeviceGuard.h
index 32cc8e5500cb0..852d6366ebd97 100644
--- a/c10/core/DeviceGuard.h
+++ b/c10/core/DeviceGuard.h
@@ -108,7 +108,7 @@ class DeviceGuard {
  *   setDevice(1);
  *   OptionalDeviceGuard g;
  *   setDevice(2);
- *   g.set_device(3);  // initializes!
+ *   g.reset_device(Device(DeviceType::CUDA, 3));  // initializes!
  *
  * On destruction, g will reset device to 2, rather than 1.
  *
@@ -118,7 +118,7 @@ class DeviceGuard {
  */
 class OptionalDeviceGuard {
 public:
-  /// Create an uninitialized guard. Set the guard later using set_device.
+  /// Create an uninitialized guard. Set the guard later using reset_device.
   explicit OptionalDeviceGuard() : guard_() {}
 
   /// Initialize the guard, setting the current device to the passed Device.
@@ -159,7 +159,7 @@ class OptionalDeviceGuard {
   }
 
   /// Returns the most recent device that was set using this device guard,
-  /// either from construction, or via set_device.
+  /// either from construction, or via reset_device.
   optional<Device> current_device() const {
     return guard_.current_device();
   }
diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h
index d7e7ec3f5c1ed..7aa45c5048d9a 100644
--- a/c10/cuda/CUDAException.h
+++ b/c10/cuda/CUDAException.h
@@ -2,7 +2,7 @@
 
 #include "c10/util/Exception.h"
 #include "c10/macros/Macros.h"
-#include "cuda.h"
+#include <cuda.h>
 
 // Note [CHECK macro]
 // ~~~~~~~~~~~~~~~~~~
diff --git a/test/common_utils.py b/test/common_utils.py
index bd260f079ff39..2bb9a8a0342ba 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -18,6 +18,7 @@
 import random
 import contextlib
 import socket
+import subprocess
 import time
 from collections import OrderedDict
 from contextlib import contextmanager
@@ -46,9 +47,12 @@
 
 parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument('--subprocess', action='store_true',
+                    help='whether to run each test in a subprocess')
 parser.add_argument('--seed', type=int, default=1234)
 parser.add_argument('--accept', action='store_true')
 args, remaining = parser.parse_known_args()
+TEST_IN_SUBPROCESS = args.subprocess
 SEED = args.seed
 if not expecttest.ACCEPT:
     expecttest.ACCEPT = args.accept
@@ -56,8 +60,61 @@
 
 torch.manual_seed(SEED)
 
+def shell(command, cwd=None):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    # The following cool snippet is copied from Py3 core library subprocess.call
+    # with only the following changes:
+    #   1. `except KeyboardInterrupt` block added for SIGINT handling.
+    #   2. In Py2, subprocess.Popen doesn't return a context manager, so we do
+    #      `p.wait()` in a `finally` block for the code to be portable.
+    #
+    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
+    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
+    try:
+        return p.wait()
+    except KeyboardInterrupt:
+        # Give `p` a chance to handle KeyboardInterrupt. Without this,
+        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
+        exit_status = p.wait(timeout=5)
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+            raise
+    except:  # noqa E722, copied from python core library
+        p.kill()
+        raise
+    finally:
+        # Always call p.wait() to ensure exit
+        p.wait()
+
+
 def run_tests(argv=UNITTEST_ARGS):
-    unittest.main(argv=argv)
+    if TEST_IN_SUBPROCESS:
+        suite = unittest.TestLoader().loadTestsFromModule(__main__)
+        test_cases = []
+
+        def add_to_test_cases(suite_or_case):
+            if isinstance(suite_or_case, unittest.TestCase):
+                test_cases.append(suite_or_case)
+            else:
+                for element in suite_or_case:
+                    add_to_test_cases(element)
+
+        add_to_test_cases(suite)
+        failed_tests = []
+        for case in test_cases:
+            test_case_full_name = case.id().split('.', 1)[1]
+            exitcode = shell([sys.executable] + argv + [test_case_full_name])
+            if exitcode != 0:
+                failed_tests.append(test_case_full_name)
+
+        assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
+            len(failed_tests), '\n\t'.join(failed_tests))
+    else:
+        unittest.main(argv=argv)
 
 PY3 = sys.version_info > (3, 0)
 PY34 = sys.version_info >= (3, 4)
diff --git a/test/run_test.py b/test/run_test.py
index 6f381c01e4240..8f406c40e1fbc 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -14,7 +14,7 @@
 import torch
 import torch._six
 from torch.utils import cpp_extension
-from common_utils import TEST_WITH_ROCM
+from common_utils import TEST_WITH_ROCM, shell
 import torch.distributed as dist
 
 TESTS = [
@@ -98,49 +98,22 @@ def print_to_stderr(message):
     print(message, file=sys.stderr)
 
 
-def shell(command, cwd=None):
-    sys.stdout.flush()
-    sys.stderr.flush()
-    # The folloing cool snippet is copied from Py3 core library subprocess.call
-    # only the with
-    #   1. `except KeyboardInterrupt` block added for SIGINT handling.
-    #   2. In Py2, subprocess.Popen doesn't return a context manager, so we do
-    #      `p.wait()` in a `final` block for the code to be portable.
-    #
-    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
-    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
-    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
-    try:
-        return p.wait()
-    except KeyboardInterrupt:
-        # Give `p` a chance to handle KeyboardInterrupt. Without this,
-        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
-        exit_status = p.wait(timeout=5)
-        if exit_status is not None:
-            return exit_status
-        else:
-            p.kill()
-            raise
-    except:  # noqa E722, copied from python core library
-        p.kill()
-        raise
-    finally:
-        # Always call p.wait() to ensure exit
-        p.wait()
-
-
-def run_test(executable, test_module, test_directory, options):
+def run_test(executable, test_module, test_directory, options, *extra_unittest_args):
     unittest_args = options.additional_unittest_args
     if options.verbose:
         unittest_args.append('--verbose')
     # Can't call `python -m unittest test_*` here because it doesn't run code
     # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
-    argv = [test_module + '.py'] + unittest_args
+    argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
 
     command = executable + argv
     return shell(command, test_directory)
 
 
+def test_cuda_primary_ctx(executable, test_module, test_directory, options):
+    return run_test(executable, test_module, test_directory, options, '--subprocess')
+
+
 def test_cpp_extensions(executable, test_module, test_directory, options):
     try:
         cpp_extension.verify_ninja_availability()
@@ -226,6 +199,7 @@ def test_distributed(executable, test_module, test_directory, options):
 
 
 CUSTOM_HANDLERS = {
+    'cuda_primary_ctx': test_cuda_primary_ctx,
     'cpp_extensions': test_cpp_extensions,
     'distributed': test_distributed,
 }
diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py
index db2023d8702a8..001288b6c18ae 100644
--- a/test/test_cuda_primary_ctx.py
+++ b/test/test_cuda_primary_ctx.py
@@ -1,9 +1,6 @@
-import ctypes
 import torch
 from common_utils import TestCase, run_tests, skipIfRocm
 import unittest
-import glob
-import os
 
 # NOTE: this needs to be run in a brand new process
 
@@ -19,49 +16,85 @@
     TestCase = object  # noqa: F811
 
 
-_caffe2_nvrtc = None
+class TestCudaPrimaryCtx(TestCase):
+    CTX_ALREADY_CREATED_ERR_MSG = (
+        "Tests defined in test_cuda_primary_ctx.py must be run in a process "
+        "where CUDA contexts are never created. Use either run_test.py or add "
+        "--subprocess to run each test in a different subprocess.")
+
+    @skipIfRocm
+    def setUp(self):
+        for device in range(torch.cuda.device_count()):
+            # Ensure context has not been created beforehand
+            self.assertFalse(torch._C._cuda_hasPrimaryContext(device), TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
+    def test_str_repr(self):
+        x = torch.randn(1, device='cuda:1')
 
-def get_is_primary_context_created(device):
-    flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
-    active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
-    global _caffe2_nvrtc
-    if _caffe2_nvrtc is None:
-        path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
-        _caffe2_nvrtc = ctypes.cdll.LoadLibrary(path)
-    result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
-    assert result == 0, 'cuDevicePrimaryCtxGetState failed'
-    return bool(active[0])
+        # We should have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
 
+        str(x)
+        repr(x)
+
+        # We should still have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
 
-class TestCudaPrimaryCtx(TestCase):
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
-    @skipIfRocm
-    def test_cuda_primary_ctx(self):
-        # Ensure context has not been created beforehand
-        self.assertFalse(get_is_primary_context_created(0))
-        self.assertFalse(get_is_primary_context_created(1))
+    def test_copy(self):
+        x = torch.randn(1, device='cuda:1')
 
+        # We should have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
+
+        y = torch.randn(1, device='cpu')
+        y.copy_(x)
+
+        # We should still have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
+
+    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
+    def test_pin_memory(self):
         x = torch.randn(1, device='cuda:1')
 
         # We should have only created context on 'cuda:1'
-        self.assertFalse(get_is_primary_context_created(0))
-        self.assertTrue(get_is_primary_context_created(1))
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
 
-        print(x)
+        x = torch.randn(3, device='cpu').pin_memory()
 
         # We should still have only created context on 'cuda:1'
-        self.assertFalse(get_is_primary_context_created(0))
-        self.assertTrue(get_is_primary_context_created(1))
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
 
-        y = torch.randn(1, device='cpu')
-        y.copy_(x)
+        x = torch.randn(3, device='cpu', pin_memory=True)
+
+        # We should still have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
+
+        x = torch.zeros(3, device='cpu', pin_memory=True)
 
         # We should still have only created context on 'cuda:1'
-        self.assertFalse(get_is_primary_context_created(0))
-        self.assertTrue(get_is_primary_context_created(1))
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
 
-        # DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS
+        x = torch.empty(3, device='cpu', pin_memory=True)
+
+        # We should still have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
+
+        x = x.pin_memory()
+
+        # We should still have only created context on 'cuda:1'
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 20eb2117c5507..660e56ac3a515 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -202,6 +202,19 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
   Py_RETURN_NONE;
 }
 
+PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg)
+{
+  HANDLE_TH_ERRORS
+  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to has_primary_context");
+  int64_t device_index = static_cast<int64_t>(THPUtils_unpackLong(arg));
+  if (at::detail::getCUDAHooks().hasPrimaryContext(device_index)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject * THCPModule_emptyCache(PyObject *_unused)
 {
   HANDLE_TH_ERRORS
@@ -383,6 +396,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
   {"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, nullptr},
   {"_cuda_getDriverVersion", (PyCFunction)THCPModule_getDriverVersion, METH_NOARGS, nullptr},
   {"_cuda_getCompiledVersion", (PyCFunction)THCPModule_getCompiledVersion, METH_NOARGS, nullptr},
+  {"_cuda_hasPrimaryContext", (PyCFunction) THCPModule_hasPrimaryContext, METH_O, nullptr},
   {"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, nullptr},
   {"_cuda_memoryAllocated", (PyCFunction) THCPModule_memoryAllocated, METH_O, nullptr},
   {"_cuda_maxMemoryAllocated", (PyCFunction) THCPModule_maxMemoryAllocated, METH_O, nullptr},
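
Usage sketch (not part of the patch): the Python snippet below mirrors what the new test_cuda_primary_ctx.py checks, and is only a hedged illustration of the intended behavior. It assumes a build that contains this patch (so the private torch._C._cuda_hasPrimaryContext binding added above exists) and a machine with at least two CUDA devices.

    # Illustrative only -- mirrors test_cuda_primary_ctx.py; run in a fresh process.
    import torch

    if torch.cuda.device_count() >= 2:
        # Touching cuda:1 creates a primary context on that device only.
        x = torch.randn(1, device='cuda:1')
        print(torch._C._cuda_hasPrimaryContext(0))  # False
        print(torch._C._cuda_hasPrimaryContext(1))  # True

        # With this patch, pinning CPU memory reuses the existing primary
        # context (here on cuda:1) instead of initializing device 0.
        y = torch.randn(3).pin_memory()
        print(torch._C._cuda_hasPrimaryContext(0))  # expected: still False

Before the change, the pinned-memory allocator always ran under the current device (device 0 by default), so a simple tensor.pin_memory() call could spawn an unwanted context on GPU 0; the sketch shows the behavior the host allocator change is meant to guarantee.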