[CUDA graphs] Private mempools for CUDA graphs (pytorch#51436)
Summary:
Implements pytorch#51075 (comment) and additions discussed offline with ezyang and ngimel. (Calling it "simple" is charitable, but it's not too bad.)

[High level strategy](https://github.com/pytorch/pytorch/pull/51436/files#diff-acc6337586bf9cdcf0a684380779300ec171897d05b8569bf439820dc8c93bd5R57-R82)

The current design aggregates stats from private pools with the ordinary pools, which may or may not be what we want.

Instead of adding PrivatePools as an internal feature of DeviceAllocator, I could inherit from DeviceAllocator (e.g., `DevicePrivateAllocator : public DeviceAllocator`) and create separate per-graph instances of the inherited class. I'm not sure whether that would be better.

Graph bindings in Python are almost unchanged from pytorch#48875:
```python
# Same bindings as 48875, but now implicitly grabs a private mempool
graph1.capture_begin()
graph1.capture_end()

# pool=... is new.  It hints that allocations during graph2's capture may share graph1's mempool
graph2.capture_begin(pool=graph1.pool())
graph2.capture_end()

# graph3 also implicitly creates its own mempool
graph3.capture_begin()
graph3.capture_end()
```
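The pool-sharing semantics above can be sketched as a toy model in pure Python. This is a hypothetical stand-in (names like `ToyGraph` are not PyTorch API) that models only the mempool-id bookkeeping, not actual CUDA capture:

```python
import itertools

# Monotonic id source; count starts at 1 because 0 is reserved to mean "unset".
_uuid = itertools.count(1)

def toy_graph_pool_handle():
    # Standalone handle: only the second slot is nonzero, which marks it
    # as coming from graph_pool_handle rather than from a capture id.
    return (0, next(_uuid))

class ToyGraph:
    """Models just the mempool bookkeeping of a graph capture."""
    def __init__(self):
        self.mempool_id = None

    def capture_begin(self, pool=(0, 0)):
        capture_id = next(_uuid)  # stands in for cudaStreamGetCaptureInfo's id
        if pool != (0, 0):
            # A user-supplied pool must have exactly one nonzero slot.
            assert not (pool[0] and pool[1])
            self.mempool_id = pool
        else:
            # No sharing requested: derive a private pool from the capture id.
            self.mempool_id = (capture_id, 0)

    def capture_end(self):
        pass

    def pool(self):
        return self.mempool_id

g1, g2, g3 = ToyGraph(), ToyGraph(), ToyGraph()
g1.capture_begin(); g1.capture_end()
g2.capture_begin(pool=g1.pool()); g2.capture_end()  # shares g1's pool
g3.capture_begin(); g3.capture_end()                # gets its own pool
```

Running the three captures mirrors the binding example: g1 and g2 end up with the same mempool id, while g3 gets a fresh private one.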

Test plan (other suggestions appreciated):

- [x] Stop maintaining manual references for all the tensors in my existing graphs+RNG tests. If private pools somehow give bad allocations, those tests should start failing intermittently. They run eager ops and eager allocations mixed with graph replays, so they may expose cases where eager ops and replays corrupt each other's memory.
- [x] `test_graph_two_successive`: Capture successive graphs, with the second graph using the first graph's result. Try with and without sharing a pool. Check results, also check memory stats to confirm sharing a pool saves memory.
- [x] `test_graph_concurrent_replay`: Capture some graphs in separate private pools, replay them concurrently in different streams, check the results to make sure they don't corrupt each other's memory. Capture some graphs with a shared pool, replay them concurrently in different streams, check results, confirm they DO corrupt each other's memory.
- [x] `test_graph_three_successive`: A three-graph case, checking the safe and unsafe replay patterns in [Restrictions of the Strawman API](pytorch#51075).
- [x] `test_graph_memory_stats_and_use_result_after_destroy_graph`: Comprehensively check torch.cuda.memory_stats() changes that result from graph capture and delete. Check that a tensor ref created during capture and held after graph delete stays valid until the tensor itself is deleted.

Pull Request resolved: pytorch#51436

Reviewed By: mruberry

Differential Revision: D26993790

Pulled By: ngimel

fbshipit-source-id: a992eaee1b8c23628e7b388a5a3c26e0f80e54da
mcarilli authored and facebook-github-bot committed Mar 12, 2021
1 parent 804f3f9 commit 90dfdef
Showing 12 changed files with 785 additions and 141 deletions.
79 changes: 76 additions & 3 deletions aten/src/ATen/cuda/CUDAGraph.cpp
@@ -1,11 +1,26 @@
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>

namespace at {
namespace cuda {

MempoolId_t graph_pool_handle() {
#if CUDA_VERSION >= 11000
// uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
static std::atomic<CaptureId_t> uuid{1};
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
return {0, uuid++};
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
return {0, 0};
#endif
}
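The `std::atomic` uuid counter above can be modeled in Python to show the property the code relies on: concurrently requested handles stay unique and nonzero. A minimal sketch (hypothetical `HandleCounter`, using a lock in place of an atomic):

```python
import threading

class HandleCounter:
    """Thread-safe counter starting at 1; 0 is reserved to mean 'unset'."""
    def __init__(self):
        self._next = 1
        self._lock = threading.Lock()  # plays the role of std::atomic

    def take(self):
        with self._lock:
            value = self._next
            self._next += 1
        return value

counter = HandleCounter()
handles = []
handles_lock = threading.Lock()

def grab(n):
    for _ in range(n):
        h = (0, counter.take())  # only the second slot set, as in graph_pool_handle
        with handles_lock:
            handles.append(h)

threads = [threading.Thread(target=grab, args=(100,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Four threads taking 100 handles each yield 400 distinct ids, none of them zero — the invariant that lets `{0, 0}` unambiguously mean "no pool supplied".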

/**
* Note [CUDA Graph Wrapper Class]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -22,6 +37,11 @@ namespace cuda {
* ops properly. Their graphs would yield invalid numerics on replay.
*/

/**
* Note [Interaction with CUDA graph capture] in CUDACachingAllocator.cpp
* describes memory management for captures.
*/

CUDAGraph::CUDAGraph()
// CUDAStreams may not be default-constructed.
: capture_stream_(at::cuda::getCurrentCUDAStream()) {
@@ -30,7 +50,7 @@ CUDAGraph::CUDAGraph()
#endif
}

void CUDAGraph::capture_begin() {
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
#if CUDA_VERSION >= 11000
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
@@ -58,16 +78,47 @@ void CUDAGraph::capture_begin() {

capture_stream_ = stream;
capture_gen_ = gen;
capture_dev_ = c10::cuda::current_device();

// cudaStreamCaptureModeGlobal is the most conservative option to
// prevent potentially unsafe CUDA API calls during capture. See
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));

// Stashes the current graph's uuid.
// Stashes the current capture's uuid.
cudaStreamCaptureStatus status;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id_));
TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

// Ensures uuid count starts at 1. 0 is reserved to mean "not set by cudaStreamGetCaptureInfo".
// (But how do we know GetCaptureInfo never sets id_ to 0? Because that's the current behavior,
// and I asked cuda devs to keep it that way, and they agreed.)
TORCH_INTERNAL_ASSERT(id_ > 0);
if (pool.first != 0 || pool.second != 0) {
// Either value being nonzero means the user supplied a pool to share.
// But only one should be nonzero.
// If pool was created by another graph's capture_begin, first should be nonzero.
// If pool was created by graph_pool_handle, second should be nonzero.
TORCH_INTERNAL_ASSERT(!(pool.first && pool.second));
mempool_id_ = pool;
} else {
// User did not ask us to share a mempool. Use our own id_ as our mempool_id_.
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
mempool_id_ = {id_, 0};
}
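The branch above reduces to a small pure function. A sketch in Python (the name `resolve_mempool_id` is hypothetical; it mirrors the C++ logic and its invariants):

```python
def resolve_mempool_id(pool, capture_id):
    """pool is a (first, second) pair; capture_id comes from the capture."""
    assert capture_id > 0  # 0 is reserved to mean "unset"
    if pool[0] != 0 or pool[1] != 0:
        # User supplied a pool to share; exactly one slot may be nonzero:
        # first nonzero -> from another capture, second nonzero -> from graph_pool_handle.
        if pool[0] and pool[1]:
            raise ValueError("malformed MempoolId: both slots nonzero")
        return pool
    # No sharing requested: derive a private pool id from the capture id,
    # setting only the first slot to mark its origin.
    return (capture_id, 0)
```

Either well-formed nonzero input passes through unchanged; the all-zero default produces `(capture_id, 0)`.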

// When CUDACachingAllocator allocates while a capture is underway, it calls cudaStreamGetCaptureInfo
// to get the current stream's capture id, if any. Here we tell CUDACachingAllocator: if the stream
// has a capture id matching this graph's id_, use the private pool mempool_id_ identifies.
//
// There's a small chance of a bad allocation here if another thread launches a kernel on
// capture_stream_ between the call to cudaStreamBeginCapture above and the call to
// notifyCaptureBegin below.
// But I don't think we need to worry about it because that use case makes no sense:
// The user has no business launching kernels on capture_stream_ from another thread
// while calling capture_begin. They'll have no idea if their side thread's
// kernel will end up as part of the capture or not.
c10::cuda::CUDACachingAllocator::notifyCaptureBegin(capture_dev_, id_, mempool_id_);
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
@@ -80,6 +131,8 @@ void CUDAGraph::capture_end() {
TORCH_CHECK(stream == capture_stream_,
"Capture must end on the same stream it began on.");

c10::cuda::CUDACachingAllocator::notifyCaptureEnd(capture_dev_, id_);

AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
TORCH_CHECK(graph_ != NULL, "Invalid capture.");
has_graph_ = true;
@@ -149,6 +202,15 @@ void CUDAGraph::reset() {
//
// Calling reset() in the C++ destructor, with warnings instead of exceptions
// if calls fail, is the compromise we chose.
//
// If capture_begin, the capture, or capture_end failed at some point, this CUDAGraph, the generator,
// and the allocator could end up in all kinds of weird states depending where failure occurred.
// If the user catches the failure exception in a script, or is running in REPL or (god forbid)
// a Jupyter notebook, I don't see an easy way for reset() to gracefully fix all such possible error states.
if (has_graph_ || has_graph_exec_) {
// notifyCaptureDestroy may throw. How should we handle this?
c10::cuda::CUDACachingAllocator::notifyCaptureDestroy(capture_dev_, mempool_id_);
}
if (has_graph_) {
C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_));
}
@@ -160,6 +222,17 @@
#endif
}

// Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
MempoolId_t CUDAGraph::pool() {
#if CUDA_VERSION >= 11000
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::pool() without a preceding successful capture.");
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#endif
return mempool_id_;
}
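The precondition guarding `pool()` can be illustrated with a tiny Python stand-in (a hypothetical `GraphPoolGuard`, not the real binding): asking for the pool id before a successful capture should fail loudly rather than hand out an uninitialized id.

```python
class GraphPoolGuard:
    """Stand-in for the has_graph_exec_ check guarding CUDAGraph::pool()."""
    def __init__(self):
        self._has_graph_exec = False
        self._mempool_id = (0, 0)

    def mark_captured(self, mempool_id):
        # Corresponds to capture_end succeeding (cudaGraphInstantiate OK).
        self._mempool_id = mempool_id
        self._has_graph_exec = True

    def pool(self):
        if not self._has_graph_exec:
            raise RuntimeError(
                "Called pool() without a preceding successful capture.")
        return self._mempool_id
```

This matches the TORCH_CHECK above: the id is only meaningful once a capture has completed and been instantiated.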

CUDAGraph::~CUDAGraph() {
reset();
}
44 changes: 39 additions & 5 deletions aten/src/ATen/cuda/CUDAGraph.h
@@ -1,39 +1,73 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/CUDAGeneratorImpl.h>

namespace at {

struct CUDAGeneratorImpl;

namespace cuda {

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();

struct TORCH_CUDA_CPP_API CUDAGraph {
CUDAGraph();
~CUDAGraph();

void capture_begin();
void capture_begin(MempoolId_t pool={0, 0});
void capture_end();
void replay();
void reset();
MempoolId_t pool();

protected:
#if CUDA_VERSION >= 11000
cudaGraph_t graph_ = NULL;
cudaGraphExec_t graph_exec_ = NULL;
#endif

// internal states for error checking
// internal states so reset() can do its best cleaning up
// Set to true in capture_end if cudaStreamEndCapture succeeded
// Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
// to create graph_exec_, then graph_ is deleted
bool has_graph_ = false;
// Set to true in capture_end if cudaGraphInstantiate succeeded
bool has_graph_exec_ = false;

// uuid, retrieved from Cuda
unsigned long long id_;
// uuid of this instance's current capture, retrieved from Cuda
CaptureId_t id_;

// uuid used to request a particular private mempool from CUDACachingAllocator.
// By default, this will be set to {id_, 0}.
//
// If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
// will be set to the other graph's mempool_id_, and therefore share a mempool with the
// other graph.
//
// If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
// it will share a mempool with any other captures that used "pool=handle".
//
// Sharing a mempool across graphs saves memory, and it's safe if you
// know you'll replay those graphs in the same order you captured them.
MempoolId_t mempool_id_;

// Stream on which capture began
at::cuda::CUDAStream capture_stream_;

// Default generator on device where capture began
at::CUDAGeneratorImpl* capture_gen_;

// Device where capture occurred. Right now, for simplicity, we require all ops
// in a capture to run on the same device, but this is a limitation of CUDAGraph,
// not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
// captures if needed.
int capture_dev_;

// RNG state trackers
at::Tensor offset_extragraph_;
uint64_t wholegraph_increment_;
Expand Down
59 changes: 11 additions & 48 deletions aten/src/ATen/cuda/CUDAGraphsUtils.cuh
@@ -5,67 +5,30 @@
#include <ATen/cuda/detail/UnpackRaw.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>

// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
// This file adds utils used by aten only.

namespace at {
namespace cuda {

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Protects against enum cudaStreamCaptureStatus implementation changes.
// Some compilers seem not to like static_assert without the messages.
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
"unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
"unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
"unexpected int(cudaStreamCaptureStatusInvalidated) value");
#endif

enum class CaptureStatus: int {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
#else
None = 0
#endif
};

inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
switch(status) {
case CaptureStatus::None:
os << "cudaStreamCaptureStatusNone";
break;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
case CaptureStatus::Active:
os << "cudaStreamCaptureStatusActive";
break;
case CaptureStatus::Invalidated:
os << "cudaStreamCaptureStatusInvalidated";
break;
#endif
default:
TORCH_INTERNAL_ASSERT(false,
"Unknown CUDA graph CaptureStatus",
int(status));
}
return os;
}
using CaptureId_t = c10::cuda::CaptureId_t;
using CaptureStatus = c10::cuda::CaptureStatus;

// Use this version where you don't want to create a CUDA context if none exists.
inline CaptureStatus currentStreamCaptureStatus() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// don't create a context if we don't have to
if (at::detail::getCUDAHooks().hasPrimaryContext(c10::cuda::current_device())) {
cudaStreamCaptureStatus is_capturing;
AT_CUDA_CHECK(cudaStreamIsCapturing(at::cuda::getCurrentCUDAStream(),
&is_capturing));
return CaptureStatus(is_capturing);
return c10::cuda::currentStreamCaptureStatusMayInitCtx();
} else {
return CaptureStatus::None;
}
#else
return CaptureStatus::None;
#endif
}

inline void assertNotCapturing(std::string attempt) {
