[CUDA graphs] Private mempools for CUDA graphs (pytorch#51436)
Summary:
Implements pytorch#51075 (comment) and additions discussed offline with ezyang and ngimel. (Calling it "simple" is charitable, but it's not too bad.)

[High level strategy](https://github.com/pytorch/pytorch/pull/51436/files#diff-acc6337586bf9cdcf0a684380779300ec171897d05b8569bf439820dc8c93bd5R57-R82)

The current design aggregates stats from private pools with the ordinary pools, which may or may not be what we want.

Instead of adding PrivatePools as an internal feature of DeviceAllocator, I could inherit from DeviceAllocator (e.g. `DevicePrivateAllocator : public DeviceAllocator`) and create separate per-graph instances of the inherited class. I'm not sure if that would be better.

Graph bindings in Python are almost unchanged from pytorch#48875:

```python
# Same bindings as 48875, but now implicitly grabs a private mempool
graph1.capture_begin()
graph1.capture_end()

# pool=... is new. It hints that allocations during graph2's capture may share graph1's mempool
graph2.capture_begin(pool=graph1.pool())
graph2.capture_end()

# graph3 also implicitly creates its own mempool
graph3.capture_begin()
graph3.capture_end()
```

Test plan (other suggestions appreciated):
- [x] Stop maintaining manual references for all the tensors in my existing graphs+RNG tests. If private pools somehow give bad allocations, those tests should start failing intermittently. They run eager ops and eager allocations mixed with graph replays, so they may expose whether eager ops and replays corrupt each other.
- [x] `test_graph_two_successive`: Capture successive graphs, with the second graph using the first graph's result. Try with and without sharing a pool. Check results, and also check memory stats to confirm that sharing a pool saves memory.
- [x] `test_graph_concurrent_replay`: Capture some graphs in separate private pools, replay them concurrently in different streams, and check the results to make sure they don't corrupt each other's memory. Then capture some graphs with a shared pool, replay them concurrently in different streams, check results, and confirm they DO corrupt each other's memory.
- [x] `test_graph_three_successive`: A three-graph case, checking the safe and unsafe replay patterns in [Restrictions of the Strawman API](pytorch#51075).
- [x] `test_graph_memory_stats_and_use_result_after_destroy_graph`: Comprehensively check the torch.cuda.memory_stats() changes that result from graph capture and deletion. Check that a tensor ref created during capture and held after graph deletion stays valid until the tensor itself is deleted.

Pull Request resolved: pytorch#51436

Reviewed By: mruberry

Differential Revision: D26993790

Pulled By: ngimel

fbshipit-source-id: a992eaee1b8c23628e7b388a5a3c26e0f80e54da
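To make the bindings concrete, here is a minimal end-to-end sketch of capturing two graphs that share one private mempool. It assumes the `torch.cuda.CUDAGraph` wrapper exposes `capture_begin`/`capture_end`/`pool`/`replay` as in the summary above; the side stream, shapes, and values are illustrative only and are not part of this diff.

```python
import torch

# A minimal usage sketch, assuming the torch.cuda.CUDAGraph wrapper around
# the bindings above. Shapes and values are illustrative only.
static_in = torch.zeros(8, device="cuda")

# Capture must happen on a non-default stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())

g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()
with torch.cuda.stream(s):
    g1.capture_begin()                # implicitly creates a private mempool
    out1 = static_in * 2
    g1.capture_end()

    g2.capture_begin(pool=g1.pool())  # hint: g2's allocations may share g1's mempool
    out2 = out1 + 1
    g2.capture_end()
torch.cuda.current_stream().wait_stream(s)

# With a shared pool, replay in capture order (the safe pattern the tests check).
static_in.fill_(3.0)
g1.replay()
g2.replay()
print(out2)  # expected: a tensor of 7s, given the ops captured above
```

Because g2's allocations may alias memory in g1's pool, replaying g1 then g2 in capture order is the safe pattern; the concurrent-replay test in the plan above exercises why concurrent replays with a shared pool are not.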
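And a rough sketch of the memory-stats comparison `test_graph_two_successive` describes. The `capture_pair` helper, the use of the standard `reserved_bytes.all.current` stat key, and the final comparison are this sketch's assumptions about how one might observe the savings, not the test's actual code.

```python
import torch

def reserved_cuda_bytes():
    # "reserved_bytes.all.current" is a standard torch.cuda.memory_stats() key;
    # using it to measure private-pool growth is this sketch's assumption.
    return torch.cuda.memory_stats()["reserved_bytes.all.current"]

def capture_pair(share_pool):
    # Hypothetical helper: capture two successive graphs, the second consuming
    # the first's result, optionally sharing the first graph's private mempool.
    x = torch.zeros(1024, 1024, device="cuda")
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.stream(s):
        g1.capture_begin()
        y = x * 2
        g1.capture_end()
        g2.capture_begin(pool=g1.pool() if share_pool else None)
        z = y + 1
        g2.capture_end()
    torch.cuda.current_stream().wait_stream(s)
    return g1, g2, x, y, z  # hold refs so the graphs and pools stay alive

before = reserved_cuda_bytes()
shared = capture_pair(share_pool=True)
shared_growth = reserved_cuda_bytes() - before

before = reserved_cuda_bytes()
separate = capture_pair(share_pool=False)
separate_growth = reserved_cuda_bytes() - before

# Sharing a pool should reserve no more (typically less) memory than two
# separate private pools; this is the flavor of check the stats test makes.
assert shared_growth <= separate_growth
```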
1 parent 804f3f9 · commit 90dfdef · 12 changed files with 785 additions and 141 deletions.