A memory pool implementation based on cnmem. Added cnmem license to LICENSE.
Showing 9 changed files with 1,868 additions and 17 deletions.
@@ -0,0 +1,118 @@
#include "third_party/cnmem/cnmem.h"
#include "caffe2/core/cuda_memorypool.h"

namespace caffe2 {

#define CNMEM_CHECK(condition) \
  do { \
    cnmemStatus_t error = condition; \
    CHECK_EQ(error, CNMEM_STATUS_SUCCESS) << cnmemGetErrorString(error); \
  } while (0)

bool CudaMemoryPool::is_memory_pool_setup_ = false;
bool CudaMemoryPool::memory_allocated_before_setup_ = false;
vector<bool> CudaMemoryPool::memory_pool_available_for_device_(0);
vector<cudaStream_t> CudaMemoryPool::per_device_streams_(0);

bool CudaMemoryPool::InitializeMemoryPool(
    const vector<int>& device_ids,
    const float proportion_of_memory_to_reserve) {
  if (memory_allocated_before_setup_) {
    LOG(ERROR) << "There is cuda memory allocated before we initialize the "
                  "memory pool. This should not happen: you should either "
                  "use raw cudaMalloc and cudaFree and not initialize the "
                  "pool at all, or initialize the pool before you allocate "
                  "anything.";
    return false;
  }
  if (is_memory_pool_setup_) {
    LOG(ERROR) << "Memory pool is already set up; it cannot be set up twice.";
    return false;
  }

  // The actual initialization.
  int device_count;
  CUDA_CHECK(cudaGetDeviceCount(&device_count));
  // Initialize the flags for the memory pool.
  memory_pool_available_for_device_.resize(device_count, false);
  per_device_streams_.resize(device_count, nullptr);
  // Push the current device so we can recover later.
  int initial_device;
  CUDA_CHECK(cudaGetDevice(&initial_device));

  vector<cnmemDevice_t> cnmem_devs(device_ids.size());
  for (int i = 0; i < device_ids.size(); ++i) {
    const int device_id = device_ids[i];
    CHECK_GE(device_id, 0);
    CHECK_LT(device_id, device_count);
    // This ensures we do not specify the same device twice.
    CHECK(!memory_pool_available_for_device_[device_id]);
    CUDA_CHECK(cudaSetDevice(device_id));
    size_t free_memory, total_memory;
    CUDA_CHECK(cudaMemGetInfo(&free_memory, &total_memory));
    LOG(INFO) << "Reserving " << proportion_of_memory_to_reserve * 100
              << " percent of the free memory (free: " << free_memory
              << " bytes) on device " << device_id;
    // Note: we create a dummy non-null stream for memory allocations, so that
    // any malloc can be called from any cuda stream, since caffe2 uses a lot
    // of non-default streams for computation. We will allocate all the
    // reserved memory to that non-null stream.
    cnmem_devs[i].device = device_id;
    cnmem_devs[i].size = size_t(proportion_of_memory_to_reserve * free_memory);
    // Index the stream by device id, since New/Delete look it up via the
    // device that the pointer lives on.
    CUDA_CHECK(cudaStreamCreate(&per_device_streams_[device_id]));
    cnmem_devs[i].numStreams = 1;
    cnmem_devs[i].streams = &per_device_streams_[device_id];
    cnmem_devs[i].streamSizes = &cnmem_devs[i].size;
    memory_pool_available_for_device_[device_id] = true;
  }
  CNMEM_CHECK(
      cnmemInit(cnmem_devs.size(), cnmem_devs.data(), CNMEM_FLAGS_DEFAULT));
  // After initialization, let's set back the device.
  CUDA_CHECK(cudaSetDevice(initial_device));
  LOG(INFO) << "Set up memory pool.";
  is_memory_pool_setup_ = true;
  return true;
}

bool CudaMemoryPool::FinalizeMemoryPool() {
  // If it has not been set up yet, we have nothing to do.
  if (!is_memory_pool_setup_) {
    return true;
  }
  CNMEM_CHECK(cnmemFinalize());
  for (int i = 0; i < per_device_streams_.size(); ++i) {
    if (per_device_streams_[i]) {
      CUDA_CHECK(cudaStreamDestroy(per_device_streams_[i]));
    }
  }
  // Reset all the static variables.
  per_device_streams_.resize(0);
  memory_pool_available_for_device_.resize(0);
  memory_allocated_before_setup_ = false;
  is_memory_pool_setup_ = false;
  return true;
}

void* CudaMemoryPool::NewWithMemoryPool(size_t nbytes) {
  int device_id;
  CUDA_CHECK(cudaGetDevice(&device_id));
  CHECK(memory_pool_available_for_device_[device_id])
      << "Trying to allocate on device " << device_id
      << ", but memory pool is not initialized on that device.";
  void* ptr;
  CNMEM_CHECK(cnmemMalloc(&ptr, nbytes, per_device_streams_[device_id]));
  return ptr;
}

void CudaMemoryPool::DeleteWithMemoryPool(void* data) {
  cudaPointerAttributes attr;
  CUDA_CHECK(cudaPointerGetAttributes(&attr, data));
  DCHECK_EQ(attr.memoryType, cudaMemoryTypeDevice);
  CHECK(memory_pool_available_for_device_[attr.device])
      << "Current pointer belongs to " << attr.device
      << ", but memory pool is not initialized on that device. "
      << "Was your pointer allocated using the memory pool?";
  CNMEM_CHECK(cnmemFree(data, per_device_streams_[attr.device]));
}

}  // namespace caffe2
@@ -0,0 +1,74 @@
#ifndef CAFFE2_CORE_CUDA_MEMORYPOOL_H_
#define CAFFE2_CORE_CUDA_MEMORYPOOL_H_

#include <cstddef>

#include "caffe2/core/common_gpu.h"
#include "glog/logging.h"

namespace caffe2 {

class CudaMemoryPool {
 public:
  // Initializes the memory pool on the given device ids, and reserves the
  // given proportion of the currently free memory on each device.
  static bool InitializeMemoryPool(
      const vector<int>& device_ids,
      const float proportion_of_memory_to_reserve);

  // Finalizes the memory pool. This has to be called after all memory
  // allocated by the memory pool has been freed.
  static bool FinalizeMemoryPool();

  static inline bool MemoryPoolInitialized() { return is_memory_pool_setup_; }
  static inline bool MemoryPoolAvailableForDevice(int device_id) {
    return (device_id < memory_pool_available_for_device_.size() &&
            memory_pool_available_for_device_[device_id]);
  }

  static inline void* New(size_t nbytes) {
    if (is_memory_pool_setup_) {
      return NewWithMemoryPool(nbytes);
    } else {
      // If the memory pool is not set up, use plain cudaMalloc.
      void* dev_ptr;
      CUDA_CHECK(cudaMalloc(&dev_ptr, nbytes));
      memory_allocated_before_setup_ = true;
      return dev_ptr;
    }
  }

  static inline void Delete(void* data) {
    if (is_memory_pool_setup_) {
      DeleteWithMemoryPool(data);
    } else {
      // If the memory pool is not set up, use plain cudaFree.
      cudaError_t error = cudaFree(data);
      // For some reason, the Python runtime sometimes deletes a data pointer
      // after the cuda runtime has already exited - this is odd but is
      // probably caused by the static workspace that pycaffe2 uses getting
      // destroyed in a racy teardown order. Since the cuda runtime is exiting
      // anyway, there is no memory leak to worry about, so we simply ignore
      // that particular error. This is not ideal but works for now.
      if (error != cudaSuccess && error != cudaErrorCudartUnloading) {
        LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": "
                   << cudaGetErrorString(error);
      }
    }
  }

 private:
  // CudaMemoryPool only has static members, so it should not be instantiated.
  CudaMemoryPool() {}
  static void* NewWithMemoryPool(size_t nbytes);
  static void DeleteWithMemoryPool(void* data);

  static bool is_memory_pool_setup_;
  static bool memory_allocated_before_setup_;
  static vector<bool> memory_pool_available_for_device_;
  static vector<cudaStream_t> per_device_streams_;
};

}  // namespace caffe2

#endif  // CAFFE2_CORE_CUDA_MEMORYPOOL_H_
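For reference, this is roughly how the pool declared above is meant to be used from application code. The sketch below is illustrative only: it calls the public API from this header, but the device list, reserve fraction, and allocation size are arbitrary example values rather than anything mandated by the commit.

// Illustrative usage sketch (device id, fraction, and size are examples).
#include "caffe2/core/cuda_memorypool.h"

int main() {
  // Reserve 50% of the currently free memory on device 0 before any other
  // cuda allocation happens; initialization must come first.
  if (!caffe2::CudaMemoryPool::InitializeMemoryPool({0}, 0.5f)) {
    return 1;
  }
  // New/Delete now route through cnmem; before initialization they would
  // have fallen back to cudaMalloc/cudaFree.
  void* ptr = caffe2::CudaMemoryPool::New(1 << 20);  // 1 MB
  caffe2::CudaMemoryPool::Delete(ptr);
  // All pool allocations must be freed before finalizing.
  caffe2::CudaMemoryPool::FinalizeMemoryPool();
  return 0;
}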
@@ -0,0 +1,64 @@
#include "caffe2/core/cuda_memorypool.h"
#include "caffe2/core/context_gpu.h"
#include "gtest/gtest.h"
#include "glog/logging.h"

namespace caffe2 {

struct UseMemoryPool { static const bool value = true; };
struct NotUseMemoryPool { static const bool value = false; };

template <class UsePoolOrNot>
class MemoryPoolTest : public ::testing::Test {
 protected:
  MemoryPoolTest() : device_count_(0) {}
  // virtual void SetUp() will be called before each test is run. You
  // should define it if you need to initialize the variables.
  // Otherwise, this can be skipped.
  void SetUp() override {
    // Assign to the member device_count_ directly; a local declaration here
    // would shadow it and leave the member at 0.
    CUDA_CHECK(cudaGetDeviceCount(&device_count_));
    // If we test with the memory pool, initialize the memory pool.
    if (UsePoolOrNot::value) {
      vector<int> device_ids(device_count_);
      for (int i = 0; i < device_count_; ++i) {
        device_ids[i] = i;
      }
      CHECK(CudaMemoryPool::InitializeMemoryPool(device_ids, 0.8));
    }
  }

  void TearDown() override {
    if (UsePoolOrNot::value) {
      CHECK(CudaMemoryPool::FinalizeMemoryPool());
    }
  }

  // Declares the variables your tests want to use.
  int device_count_;
};

typedef ::testing::Types<UseMemoryPool, NotUseMemoryPool> MemoryPoolTestTypes;
TYPED_TEST_CASE(MemoryPoolTest, MemoryPoolTestTypes);

// This just tests that setup and teardown work.
TYPED_TEST(MemoryPoolTest, InitializeAndFinalizeWorks) {
  EXPECT_TRUE(true);
}

TYPED_TEST(MemoryPoolTest, AllocateAndDeallocate) {
  const int nbytes = 1048576;
  for (int i = 0; i < this->device_count_; ++i) {
    LOG(INFO) << "Device " << i << " of " << this->device_count_;
    CUDA_CHECK(cudaSetDevice(i));
    void* allocated = CUDAContext::New(nbytes);
    EXPECT_NE(allocated, nullptr);
    cudaPointerAttributes attr;
    CUDA_CHECK(cudaPointerGetAttributes(&attr, allocated));
    EXPECT_EQ(attr.memoryType, cudaMemoryTypeDevice);
    EXPECT_EQ(attr.device, i);
    CUDAContext::Delete(allocated);
  }
}

}  // namespace caffe2
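Note that the allocation test goes through CUDAContext::New and CUDAContext::Delete rather than calling the pool directly, so the same test exercises both the pooled and the plain cudaMalloc path depending on the fixture type. The context-side wiring lives in other files of this commit that are not shown above; purely as an assumed sketch (not the actual context_gpu code), the forwarding could look like this:

// Assumed sketch only - the real CUDAContext/context_gpu changes from this
// commit are not reproduced here; the struct name is hypothetical.
#include "caffe2/core/cuda_memorypool.h"

namespace caffe2 {
struct CUDAContextAllocatorSketch {
  static void* New(size_t nbytes) {
    // CudaMemoryPool::New falls back to plain cudaMalloc if no pool is set up.
    return CudaMemoryPool::New(nbytes);
  }
  static void Delete(void* data) {
    // Likewise falls back to plain cudaFree when the pool is not initialized.
    CudaMemoryPool::Delete(data);
  }
};
}  // namespace caffe2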
@@ -0,0 +1,9 @@
cuda_library(
  name = "cnmem",
  srcs = [
    "cnmem.cpp",
  ],
  hdrs = [
    "cnmem.h",
  ],
)