From 1ddf529e199dcf7dd7dff7ceeb73181e81ba948a Mon Sep 17 00:00:00 2001 From: Chris Choy Date: Tue, 19 May 2020 18:11:23 -0700 Subject: [PATCH] optional memory manager support Squashed commit of the following: commit 67ade204ea59b3c4488c3361093ee424dd22ddd8 Author: Chris Choy Date: Wed May 20 01:00:57 2020 +0000 error fix gpu memman commit 0ffe8ac40ccf5264fe90c37e1daef74520cec8cf Author: Chris Choy Date: Sun May 17 13:09:10 2020 -0700 memory manager control commit dc479ce7a7ed6112db8ff4ab9476f22dcfc0f9d2 Author: Chris Choy Date: Sun May 17 01:05:49 2020 -0700 memory manager backend --- CHANGELOG.md | 2 + MinkowskiEngine/MinkowskiConvolution.py | 10 +-- MinkowskiEngine/MinkowskiCoords.py | 33 +++++++--- MinkowskiEngine/SparseTensor.py | 6 +- MinkowskiEngine/__init__.py | 4 +- pybind/minkowski.cpp | 41 +++++++----- src/coords_manager.cu | 4 +- src/coords_manager.hpp | 15 +++-- src/gpu_memory_manager.hpp | 83 +++++++++++++++++++++---- tests/coords.py | 10 ++- 10 files changed, 157 insertions(+), 51 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b32cf67..2a127cb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ - SyncBatchNorm error fix (Issue #129) - Sparse Tensor `dense()` doc update (Issue #126) - Installation arguments `--cuda_home=`, `--force_cuda`, `--blas_include_dirs=`, and '--blas_library_dirs=`. (Issue #135) +- SparseTensor query by coordinates `features_at_coords` (Issue #137) +- Memory manager control. CUDA | Pytorch memory manager for cuda malloc ## [0.4.2] - 2020-03-13 diff --git a/MinkowskiEngine/MinkowskiConvolution.py b/MinkowskiEngine/MinkowskiConvolution.py index 29e75633..a1b95474 100644 --- a/MinkowskiEngine/MinkowskiConvolution.py +++ b/MinkowskiEngine/MinkowskiConvolution.py @@ -275,10 +275,10 @@ def forward(self, # Get a new coords key or extract one from the coords out_coords_key = _get_coords_key(input, coords) outfeat = conv.apply(input.F, self.kernel, input.tensor_stride, - self.stride, self.kernel_size, - self.dilation, self.region_type_, - self.region_offset_, input.coords_key, - out_coords_key, input.coords_man) + self.stride, self.kernel_size, self.dilation, + self.region_type_, self.region_offset_, + input.coords_key, out_coords_key, + input.coords_man) if self.has_bias: outfeat += self.bias @@ -498,4 +498,4 @@ def forward(self, outfeat += self.bias return SparseTensor( - outfeat, coords_key=out_coords_key, coords_manager=input.coords_man) \ No newline at end of file + outfeat, coords_key=out_coords_key, coords_manager=input.coords_man) diff --git a/MinkowskiEngine/MinkowskiCoords.py b/MinkowskiEngine/MinkowskiCoords.py index f15705de..1ebeb13b 100644 --- a/MinkowskiEngine/MinkowskiCoords.py +++ b/MinkowskiEngine/MinkowskiCoords.py @@ -29,11 +29,21 @@ import torch from Common import convert_to_int_list, convert_to_int_tensor, prep_args import MinkowskiEngineBackend as MEB +from MinkowskiEngineBackend import MemoryManagerBackend CPU_COUNT = os.cpu_count() if 'OMP_NUM_THREADS' in os.environ: CPU_COUNT = int(os.environ['OMP_NUM_THREADS']) +_memory_manager_backend = MemoryManagerBackend.CUDA + + +def set_memory_manager_backend(backend: MemoryManagerBackend): + assert isinstance(backend, MemoryManagerBackend), \ + f"Input must be an instance of MemoryManagerBackend not {backend}" + global _memory_manager_backend + _memory_manager_backend = backend + class CoordsKey(): @@ -68,14 +78,19 @@ def __eq__(self, other): class CoordsManager(): - def __init__(self, num_threads: int = -1, D: int = -1): + def __init__(self, + num_threads: int = -1, + memory_manager_backend: MemoryManagerBackend = None, + D: int = -1): if D < 1: raise ValueError(f"Invalid dimension {D}") self.D = D - CPPCoordsManager = MEB.CoordsManager if num_threads < 0: - num_threads = CPU_COUNT - coords_man = CPPCoordsManager(num_threads) + num_threads = min(CPU_COUNT, 20) + if memory_manager_backend is None: + global _memory_manager_backend + memory_manager_backend = _memory_manager_backend + coords_man = MEB.CoordsManager(num_threads, memory_manager_backend) self.CPPCoordsManager = coords_man def initialize(self, @@ -88,10 +103,9 @@ def initialize(self, assert isinstance(coords_key, CoordsKey) unique_index = torch.LongTensor() inverse_mapping = torch.LongTensor() - self.CPPCoordsManager.initializeCoords(coords, unique_index, inverse_mapping, - coords_key.CPPCoordsKey, - force_creation, force_remap, - allow_duplicate_coords, return_inverse) + self.CPPCoordsManager.initializeCoords( + coords, unique_index, inverse_mapping, coords_key.CPPCoordsKey, + force_creation, force_remap, allow_duplicate_coords, return_inverse) return unique_index, inverse_mapping def create_coords_key(self, @@ -242,7 +256,8 @@ def get_kernel_map(self, """ # region type 1 iteration with kernel_size 1 is invalid if isinstance(kernel_size, torch.Tensor): - assert (kernel_size > 0).all(), f"Invalid kernel size: {kernel_size}" + assert (kernel_size > + 0).all(), f"Invalid kernel size: {kernel_size}" if (kernel_size == 1).all() == 1: region_type = 0 elif isinstance(kernel_size, int): diff --git a/MinkowskiEngine/SparseTensor.py b/MinkowskiEngine/SparseTensor.py index 485bdd6d..0b9d6a46 100644 --- a/MinkowskiEngine/SparseTensor.py +++ b/MinkowskiEngine/SparseTensor.py @@ -33,6 +33,7 @@ from Common import convert_to_int_list from MinkowskiCoords import CoordsKey, CoordsManager import MinkowskiEngineBackend as MEB +from MinkowskiEngineBackend import MemoryManagerBackend class SparseTensorOperationMode(Enum): @@ -146,6 +147,7 @@ def __init__( force_creation=False, allow_duplicate_coords=False, quantization_mode=SparseTensorQuantizationMode.RANDOM_SUBSAMPLE, + memory_manager_backend: MemoryManagerBackend = None, tensor_stride=1): r""" @@ -251,7 +253,9 @@ def __init__( global _sparse_tensor_operation_mode, _global_coords_man if _sparse_tensor_operation_mode == SparseTensorOperationMode.SHARE_COORDS_MANAGER: if _global_coords_man is None: - _global_coords_man = CoordsManager(D=coords.size(1) - 1) + _global_coords_man = CoordsManager( + memory_manager_backend=memory_manager_backend, + D=coords.size(1) - 1) coords_manager = _global_coords_man else: assert coords is not None, "Initial coordinates must be given" diff --git a/MinkowskiEngine/__init__.py b/MinkowskiEngine/__init__.py index 2360b408..389b89c8 100644 --- a/MinkowskiEngine/__init__.py +++ b/MinkowskiEngine/__init__.py @@ -32,13 +32,15 @@ # Must be imported first to load all required shared libs import torch +from MinkowskiEngineBackend import MemoryManagerBackend + from SparseTensor import SparseTensor, SparseTensorOperationMode, SparseTensorQuantizationMode, \ set_sparse_tensor_operation_mode, sparse_tensor_operation_mode, clear_global_coords_man from Common import RegionType, convert_to_int_tensor, convert_region_type, \ MinkowskiModuleBase, KernelGenerator, GlobalPoolingMode -from MinkowskiCoords import CoordsKey, CoordsManager +from MinkowskiCoords import CoordsKey, CoordsManager, set_memory_manager_backend from MinkowskiConvolution import MinkowskiConvolutionFunction, MinkowskiConvolution, \ MinkowskiConvolutionTransposeFunction, MinkowskiConvolutionTranspose diff --git a/pybind/minkowski.cpp b/pybind/minkowski.cpp index 7095c452..9c2ec7b1 100644 --- a/pybind/minkowski.cpp +++ b/pybind/minkowski.cpp @@ -189,11 +189,11 @@ void instantiate_func(py::module &m, const std::string &dtypestr) { #endif } -template -void instantiate_coordsman(py::module &m) { +template void instantiate_coordsman(py::module &m) { std::string coords_name = std::string("CoordsManager"); py::class_>(m, coords_name.c_str()) .def(py::init()) + .def(py::init()) .def("existsCoordsKey", (bool (mink::CoordsManager::*)(py::object) const) & mink::CoordsManager::existsCoordsKey) @@ -204,40 +204,44 @@ void instantiate_coordsman(py::module &m) { #endif .def("getCoordsMap", &mink::CoordsManager::getCoordsMap) .def("getUnionMap", &mink::CoordsManager::getUnionMap) - .def("getCoordsSize", (int (mink::CoordsManager::*)(py::object) const) & - mink::CoordsManager::getCoordsSize) + .def("getCoordsSize", + (int (mink::CoordsManager::*)(py::object) const) & + mink::CoordsManager::getCoordsSize) .def("getCoords", &mink::CoordsManager::getCoords) .def("getBatchSize", &mink::CoordsManager::getBatchSize) .def("getBatchIndices", &mink::CoordsManager::getBatchIndices) .def("getRowIndicesAtBatchIndex", &mink::CoordsManager::getRowIndicesAtBatchIndex) - .def("getRowIndicesPerBatch", &mink::CoordsManager::getRowIndicesPerBatch) - .def("setOriginCoordsKey", &mink::CoordsManager::setOriginCoordsKey) + .def("getRowIndicesPerBatch", + &mink::CoordsManager::getRowIndicesPerBatch) + .def("setOriginCoordsKey", + &mink::CoordsManager::setOriginCoordsKey) .def("initializeCoords", - (uint64_t(mink::CoordsManager::*)(at::Tensor, at::Tensor, at::Tensor, - py::object, const bool, const bool, - const bool, const bool)) & + (uint64_t(mink::CoordsManager::*)( + at::Tensor, at::Tensor, at::Tensor, py::object, const bool, + const bool, const bool, const bool)) & mink::CoordsManager::initializeCoords, py::call_guard()) - .def("createStridedCoords", &mink::CoordsManager::createStridedCoords) + .def("createStridedCoords", + &mink::CoordsManager::createStridedCoords) .def("createTransposedStridedRegionCoords", &mink::CoordsManager::createTransposedStridedRegionCoords) - .def("createPrunedCoords", &mink::CoordsManager::createPrunedCoords) - .def("createOriginCoords", &mink::CoordsManager::createOriginCoords) + .def("createPrunedCoords", + &mink::CoordsManager::createPrunedCoords) + .def("createOriginCoords", + &mink::CoordsManager::createOriginCoords) .def("printDiagnostics", &mink::CoordsManager::printDiagnostics) .def("__repr__", [](const mink::CoordsManager &a) { return a.toString(); }); } -template -void instantiate(py::module &m) { +template void instantiate(py::module &m) { instantiate_coordsman(m); instantiate_func(m, std::string("f")); instantiate_func(m, std::string("d")); } -template -void bind_native(py::module &m) { +template void bind_native(py::module &m) { std::string name = std::string("CoordsKey"); py::class_(m, name.c_str()) .def(py::init<>()) @@ -260,6 +264,11 @@ void bind_native(py::module &m) { } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::enum_(m, "MemoryManagerBackend") + .value("CUDA", mink::MemoryManagerBackend::CUDA) + .value("PYTORCH", mink::MemoryManagerBackend::PYTORCH) + .export_values(); + bind_native(m); instantiate(m); } diff --git a/src/coords_manager.cu b/src/coords_manager.cu index fd26c3e0..9f0982cd 100644 --- a/src/coords_manager.cu +++ b/src/coords_manager.cu @@ -45,7 +45,7 @@ CoordsManager::copyInOutMapToGPU(const InOutMaps &map) { pInOutMaps d_map; const int n = getInOutMapsSize(map); - int *d_scr = (int *)gpu_memory_manager.gpuMalloc(n * sizeof(int)); + int *d_scr = (int *)gpu_memory_manager.get()->gpuMalloc(n * sizeof(int)); for (const auto &cmap : map) { // Copy (*p_in_maps)[k] to GPU @@ -166,7 +166,7 @@ vector> CoordsManager::getKernelMapGPU( torch::TensorOptions() .dtype(torch::kInt64) // .device(torch::kCUDA) - .device(torch::kCUDA, gpu_memory_manager.device_id) + .device(torch::kCUDA, gpu_memory_manager.get()->get_device_id()) .requires_grad(false); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/src/coords_manager.hpp b/src/coords_manager.hpp index bbea720a..3fe4fa89 100644 --- a/src/coords_manager.hpp +++ b/src/coords_manager.hpp @@ -122,11 +122,18 @@ template class CoordsManager { unordered_map, InOutMapKeyHash> in_maps; unordered_map, InOutMapKeyHash> out_maps; - CoordsManager(){}; + CoordsManager(){ + gpu_memory_manager = std::make_shared(); + }; CoordsManager(int num_threads) { omp_set_dynamic(0); omp_set_num_threads(num_threads); } + CoordsManager(int num_threads, MemoryManagerBackend backend) { + omp_set_dynamic(0); + omp_set_num_threads(num_threads); + gpu_memory_manager = std::make_shared(backend); + } ~CoordsManager() { clear(); } void printDiagnostics(py::object py_coords_key) const; @@ -263,7 +270,7 @@ template class CoordsManager { #ifndef CPU_ONLY // GPU memory manager - GPUMemoryManager gpu_memory_manager; + std::shared_ptr gpu_memory_manager; // Keep all in out maps throughout the lifecycle of the coords manager // @@ -294,10 +301,10 @@ template class CoordsManager { py::object py_out_coords_key); void *getScratchGPUMemory(size_t size) { - return gpu_memory_manager.tmp_data(size); + return gpu_memory_manager.get()->tmp_data(size); } - void clearScratchGPUMemory() { gpu_memory_manager.clear_tmp(); } + void clearScratchGPUMemory() { gpu_memory_manager.get()->clear_tmp(); } #endif // CPU_ONLY }; // coordsmanager diff --git a/src/gpu_memory_manager.hpp b/src/gpu_memory_manager.hpp index af632b07..94c6e508 100644 --- a/src/gpu_memory_manager.hpp +++ b/src/gpu_memory_manager.hpp @@ -1,10 +1,39 @@ +/* Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + * Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural + * Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part + * of the code. + */ #ifndef CPU_ONLY #ifndef GPU_MEMORY_MANAGER #define GPU_MEMORY_MANAGER +#include #include +#include +#include +#include + #include "gpu.cuh" #include "types.hpp" @@ -12,22 +41,39 @@ namespace minkowski { using std::vector; +enum MemoryManagerBackend { CUDA = 0, PYTORCH = 1 }; + class GPUMemoryManager { +private: int initial_size = 256; - -public: + MemoryManagerBackend backend; int device_id; +public: // A set of data that will be not be freed untill the class is destroyed. vector persist_vec_ptr; vector tmp_vec_ptr; // Memory manager simply allocates and free memory when done. - GPUMemoryManager() { CUDA_CHECK(cudaGetDevice(&device_id)); } - GPUMemoryManager(int size) : initial_size(size) { GPUMemoryManager(); } + GPUMemoryManager(MemoryManagerBackend backend_) : backend(backend_) { + CUDA_CHECK(cudaGetDevice(&device_id)); + // std::cout << "GPU set to " << device_id << "\n"; + } + GPUMemoryManager() : GPUMemoryManager(PYTORCH) {} // use pytorch by default ~GPUMemoryManager() { - for (auto p_buffer : persist_vec_ptr) { - cudaFree(p_buffer); + switch (backend) { + case CUDA: { + for (auto p_buffer : persist_vec_ptr) { + cudaFree(p_buffer); + } + break; + } + case PYTORCH: { + for (auto p_buffer : persist_vec_ptr) { + c10::cuda::CUDACachingAllocator::raw_delete(p_buffer); + } + break; + } } } @@ -40,9 +86,8 @@ class GPUMemoryManager { tmp_vec_ptr.clear(); } - void set_device() { - CUDA_CHECK(cudaSetDevice(device_id)); - } + void set_device() { CUDA_CHECK(cudaSetDevice(device_id)); } + int get_device_id() const { return device_id; } void *tmp_data(size_t size) { void *p_buffer = NULL; @@ -54,9 +99,23 @@ class GPUMemoryManager { void *gpuMalloc(size_t size) { void *p_buffer = NULL; - CUDA_CHECK(cudaSetDevice(device_id)); - CUDA_CHECK(cudaMalloc(&p_buffer, size)); - persist_vec_ptr.push_back(p_buffer); + switch (backend) { + case CUDA: { + // std::cout << "Malloc CUDA: " << device_id << std::endl; + CUDA_CHECK(cudaSetDevice(device_id)); + CUDA_CHECK(cudaMalloc(&p_buffer, size)); + persist_vec_ptr.push_back(p_buffer); + break; + } + case PYTORCH: { + // std::cout << "Malloc PYTORCH: " << device_id << std::endl; + CUDA_CHECK(cudaSetDevice(device_id)); + p_buffer = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream( + size, at::cuda::getCurrentCUDAStream()); + persist_vec_ptr.push_back(p_buffer); + break; + } + } return p_buffer; } }; diff --git a/tests/coords.py b/tests/coords.py index eb70beef..2e24ba64 100644 --- a/tests/coords.py +++ b/tests/coords.py @@ -26,7 +26,8 @@ import torch import numpy as np -from MinkowskiEngine import CoordsKey, CoordsManager +import MinkowskiEngine as ME +from MinkowskiEngine import CoordsKey, CoordsManager, MemoryManagerBackend from tests.common import data_loader @@ -167,6 +168,13 @@ def test_batch_size_initialize(self): self.assertTrue(cm.get_batch_size() == 2) + def test_memory_manager_backend(self): + CoordsManager(memory_manager_backend=MemoryManagerBackend.CUDA, D=2) + CoordsManager(memory_manager_backend=MemoryManagerBackend.PYTORCH, D=2) + + ME.set_memory_manager_backend(MemoryManagerBackend.PYTORCH) + CoordsManager(D=2) + if __name__ == '__main__': unittest.main()