Skip to content

Commit

Permalink
optional memory manager support
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit 67ade204ea59b3c4488c3361093ee424dd22ddd8
Author: Chris Choy <cchoy@nvidia.com>
Date:   Wed May 20 01:00:57 2020 +0000

    error fix gpu memman

commit 0ffe8ac40ccf5264fe90c37e1daef74520cec8cf
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sun May 17 13:09:10 2020 -0700

    memory manager control

commit dc479ce7a7ed6112db8ff4ab9476f22dcfc0f9d2
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sun May 17 01:05:49 2020 -0700

    memory manager backend
  • Loading branch information
chrischoy committed May 20, 2020
1 parent 6618442 commit 1ddf529
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 51 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
- SyncBatchNorm error fix (Issue #129)
- Sparse Tensor `dense()` doc update (Issue #126)
- Installation arguments `--cuda_home=<value>`, `--force_cuda`, `--blas_include_dirs=<comma_separated_values>`, and `--blas_library_dirs=<comma_separated_values>`. (Issue #135)
- SparseTensor query by coordinates `features_at_coords` (Issue #137)
- Memory manager control: select either the CUDA or the PyTorch memory manager backend for CUDA memory allocation


## [0.4.2] - 2020-03-13
Expand Down
10 changes: 5 additions & 5 deletions MinkowskiEngine/MinkowskiConvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ def forward(self,
# Get a new coords key or extract one from the coords
out_coords_key = _get_coords_key(input, coords)
outfeat = conv.apply(input.F, self.kernel, input.tensor_stride,
self.stride, self.kernel_size,
self.dilation, self.region_type_,
self.region_offset_, input.coords_key,
out_coords_key, input.coords_man)
self.stride, self.kernel_size, self.dilation,
self.region_type_, self.region_offset_,
input.coords_key, out_coords_key,
input.coords_man)
if self.has_bias:
outfeat += self.bias

Expand Down Expand Up @@ -498,4 +498,4 @@ def forward(self,
outfeat += self.bias

return SparseTensor(
outfeat, coords_key=out_coords_key, coords_manager=input.coords_man)
outfeat, coords_key=out_coords_key, coords_manager=input.coords_man)
33 changes: 24 additions & 9 deletions MinkowskiEngine/MinkowskiCoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,21 @@
import torch
from Common import convert_to_int_list, convert_to_int_tensor, prep_args
import MinkowskiEngineBackend as MEB
from MinkowskiEngineBackend import MemoryManagerBackend

CPU_COUNT = os.cpu_count()
if 'OMP_NUM_THREADS' in os.environ:
CPU_COUNT = int(os.environ['OMP_NUM_THREADS'])

_memory_manager_backend = MemoryManagerBackend.CUDA


def set_memory_manager_backend(backend: MemoryManagerBackend):
    r"""Set the module-wide default GPU memory manager backend.

    The chosen backend is stored in the module-level
    ``_memory_manager_backend`` and is used by every ``CoordsManager``
    created without an explicit ``memory_manager_backend`` argument.
    """
    global _memory_manager_backend
    assert isinstance(backend, MemoryManagerBackend), \
        f"Input must be an instance of MemoryManagerBackend not {backend}"
    _memory_manager_backend = backend


class CoordsKey():

Expand Down Expand Up @@ -68,14 +78,19 @@ def __eq__(self, other):

class CoordsManager():

def __init__(self,
             num_threads: int = -1,
             memory_manager_backend: MemoryManagerBackend = None,
             D: int = -1):
    r"""Create a wrapper around the C++ CoordsManager.

    Args:
        num_threads (int): number of OpenMP threads used for coordinate
            management. Negative values fall back to
            ``min(CPU_COUNT, 20)``.
        memory_manager_backend (MemoryManagerBackend): backend used for
            GPU memory allocation (CUDA or PYTORCH). ``None`` falls back
            to the module-level default set by
            :func:`set_memory_manager_backend`.
        D (int): spatial dimension of the coordinates; must be >= 1.

    Raises:
        ValueError: if ``D < 1``.
    """
    if D < 1:
        raise ValueError(f"Invalid dimension {D}")
    self.D = D
    if num_threads < 0:
        # Cap the default to avoid oversubscription on many-core hosts.
        num_threads = min(CPU_COUNT, 20)
    if memory_manager_backend is None:
        global _memory_manager_backend
        memory_manager_backend = _memory_manager_backend
    # Instantiate the pybind-wrapped C++ coordinate manager directly;
    # the old `CPPCoordsManager = MEB.CoordsManager` alias was dead code.
    self.CPPCoordsManager = MEB.CoordsManager(num_threads,
                                              memory_manager_backend)

def initialize(self,
Expand All @@ -88,10 +103,9 @@ def initialize(self,
assert isinstance(coords_key, CoordsKey)
unique_index = torch.LongTensor()
inverse_mapping = torch.LongTensor()
self.CPPCoordsManager.initializeCoords(coords, unique_index, inverse_mapping,
coords_key.CPPCoordsKey,
force_creation, force_remap,
allow_duplicate_coords, return_inverse)
self.CPPCoordsManager.initializeCoords(
coords, unique_index, inverse_mapping, coords_key.CPPCoordsKey,
force_creation, force_remap, allow_duplicate_coords, return_inverse)
return unique_index, inverse_mapping

def create_coords_key(self,
Expand Down Expand Up @@ -242,7 +256,8 @@ def get_kernel_map(self,
"""
# region type 1 iteration with kernel_size 1 is invalid
if isinstance(kernel_size, torch.Tensor):
assert (kernel_size > 0).all(), f"Invalid kernel size: {kernel_size}"
assert (kernel_size >
0).all(), f"Invalid kernel size: {kernel_size}"
if (kernel_size == 1).all() == 1:
region_type = 0
elif isinstance(kernel_size, int):
Expand Down
6 changes: 5 additions & 1 deletion MinkowskiEngine/SparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from Common import convert_to_int_list
from MinkowskiCoords import CoordsKey, CoordsManager
import MinkowskiEngineBackend as MEB
from MinkowskiEngineBackend import MemoryManagerBackend


class SparseTensorOperationMode(Enum):
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
force_creation=False,
allow_duplicate_coords=False,
quantization_mode=SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
memory_manager_backend: MemoryManagerBackend = None,
tensor_stride=1):
r"""
Expand Down Expand Up @@ -251,7 +253,9 @@ def __init__(
global _sparse_tensor_operation_mode, _global_coords_man
if _sparse_tensor_operation_mode == SparseTensorOperationMode.SHARE_COORDS_MANAGER:
if _global_coords_man is None:
_global_coords_man = CoordsManager(D=coords.size(1) - 1)
_global_coords_man = CoordsManager(
memory_manager_backend=memory_manager_backend,
D=coords.size(1) - 1)
coords_manager = _global_coords_man
else:
assert coords is not None, "Initial coordinates must be given"
Expand Down
4 changes: 3 additions & 1 deletion MinkowskiEngine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
# Must be imported first to load all required shared libs
import torch

from MinkowskiEngineBackend import MemoryManagerBackend

from SparseTensor import SparseTensor, SparseTensorOperationMode, SparseTensorQuantizationMode, \
set_sparse_tensor_operation_mode, sparse_tensor_operation_mode, clear_global_coords_man

from Common import RegionType, convert_to_int_tensor, convert_region_type, \
MinkowskiModuleBase, KernelGenerator, GlobalPoolingMode

from MinkowskiCoords import CoordsKey, CoordsManager
from MinkowskiCoords import CoordsKey, CoordsManager, set_memory_manager_backend

from MinkowskiConvolution import MinkowskiConvolutionFunction, MinkowskiConvolution, \
MinkowskiConvolutionTransposeFunction, MinkowskiConvolutionTranspose
Expand Down
41 changes: 25 additions & 16 deletions pybind/minkowski.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ void instantiate_func(py::module &m, const std::string &dtypestr) {
#endif
}

template <typename MapType>
void instantiate_coordsman(py::module &m) {
template <typename MapType> void instantiate_coordsman(py::module &m) {
std::string coords_name = std::string("CoordsManager");
py::class_<mink::CoordsManager<MapType>>(m, coords_name.c_str())
.def(py::init<int>())
.def(py::init<int, mink::MemoryManagerBackend>())
.def("existsCoordsKey",
(bool (mink::CoordsManager<MapType>::*)(py::object) const) &
mink::CoordsManager<MapType>::existsCoordsKey)
Expand All @@ -204,40 +204,44 @@ void instantiate_coordsman(py::module &m) {
#endif
.def("getCoordsMap", &mink::CoordsManager<MapType>::getCoordsMap)
.def("getUnionMap", &mink::CoordsManager<MapType>::getUnionMap)
.def("getCoordsSize", (int (mink::CoordsManager<MapType>::*)(py::object) const) &
mink::CoordsManager<MapType>::getCoordsSize)
.def("getCoordsSize",
(int (mink::CoordsManager<MapType>::*)(py::object) const) &
mink::CoordsManager<MapType>::getCoordsSize)
.def("getCoords", &mink::CoordsManager<MapType>::getCoords)
.def("getBatchSize", &mink::CoordsManager<MapType>::getBatchSize)
.def("getBatchIndices", &mink::CoordsManager<MapType>::getBatchIndices)
.def("getRowIndicesAtBatchIndex",
&mink::CoordsManager<MapType>::getRowIndicesAtBatchIndex)
.def("getRowIndicesPerBatch", &mink::CoordsManager<MapType>::getRowIndicesPerBatch)
.def("setOriginCoordsKey", &mink::CoordsManager<MapType>::setOriginCoordsKey)
.def("getRowIndicesPerBatch",
&mink::CoordsManager<MapType>::getRowIndicesPerBatch)
.def("setOriginCoordsKey",
&mink::CoordsManager<MapType>::setOriginCoordsKey)
.def("initializeCoords",
(uint64_t(mink::CoordsManager<MapType>::*)(at::Tensor, at::Tensor, at::Tensor,
py::object, const bool, const bool,
const bool, const bool)) &
(uint64_t(mink::CoordsManager<MapType>::*)(
at::Tensor, at::Tensor, at::Tensor, py::object, const bool,
const bool, const bool, const bool)) &
mink::CoordsManager<MapType>::initializeCoords,
py::call_guard<py::gil_scoped_release>())
.def("createStridedCoords", &mink::CoordsManager<MapType>::createStridedCoords)
.def("createStridedCoords",
&mink::CoordsManager<MapType>::createStridedCoords)
.def("createTransposedStridedRegionCoords",
&mink::CoordsManager<MapType>::createTransposedStridedRegionCoords)
.def("createPrunedCoords", &mink::CoordsManager<MapType>::createPrunedCoords)
.def("createOriginCoords", &mink::CoordsManager<MapType>::createOriginCoords)
.def("createPrunedCoords",
&mink::CoordsManager<MapType>::createPrunedCoords)
.def("createOriginCoords",
&mink::CoordsManager<MapType>::createOriginCoords)
.def("printDiagnostics", &mink::CoordsManager<MapType>::printDiagnostics)
.def("__repr__",
[](const mink::CoordsManager<MapType> &a) { return a.toString(); });
}

// Register all bindings parameterized on MapType: the CoordsManager class
// plus the float ("f") and double ("d") functional instantiations.
template <typename MapType> void instantiate(py::module &m) {
  instantiate_coordsman<MapType>(m);
  instantiate_func<MapType, float>(m, std::string("f"));
  instantiate_func<MapType, double>(m, std::string("d"));
}

template <typename MapType>
void bind_native(py::module &m) {
template <typename MapType> void bind_native(py::module &m) {
std::string name = std::string("CoordsKey");
py::class_<mink::CoordsKey>(m, name.c_str())
.def(py::init<>())
Expand All @@ -260,6 +264,11 @@ void bind_native(py::module &m) {
}

// Python extension entry point: registers the memory-backend enum, the
// native CoordsKey bindings, and the templated CoordsManager/functional
// bindings for the CoordsToIndexMap map type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose the GPU allocation backend choice to Python; export_values()
  // also re-exports CUDA/PYTORCH directly at module scope.
  py::enum_<mink::MemoryManagerBackend>(m, "MemoryManagerBackend")
      .value("CUDA", mink::MemoryManagerBackend::CUDA)
      .value("PYTORCH", mink::MemoryManagerBackend::PYTORCH)
      .export_values();

  bind_native<mink::CoordsToIndexMap>(m);   // CoordsKey et al.
  instantiate<mink::CoordsToIndexMap>(m);   // CoordsManager + f/d kernels
}
4 changes: 2 additions & 2 deletions src/coords_manager.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ CoordsManager<MapType>::copyInOutMapToGPU(const InOutMaps<int> &map) {
pInOutMaps<int> d_map;

const int n = getInOutMapsSize(map);
int *d_scr = (int *)gpu_memory_manager.gpuMalloc(n * sizeof(int));
int *d_scr = (int *)gpu_memory_manager.get()->gpuMalloc(n * sizeof(int));

for (const auto &cmap : map) {
// Copy (*p_in_maps)[k] to GPU
Expand Down Expand Up @@ -166,7 +166,7 @@ vector<vector<at::Tensor>> CoordsManager<MapType>::getKernelMapGPU(
torch::TensorOptions()
.dtype(torch::kInt64)
// .device(torch::kCUDA)
.device(torch::kCUDA, gpu_memory_manager.device_id)
.device(torch::kCUDA, gpu_memory_manager.get()->get_device_id())
.requires_grad(false);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down
15 changes: 11 additions & 4 deletions src/coords_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,18 @@ template <typename MapType = CoordsToIndexMap> class CoordsManager {
unordered_map<InOutMapKey, InOutMaps<int>, InOutMapKeyHash> in_maps;
unordered_map<InOutMapKey, InOutMaps<int>, InOutMapKeyHash> out_maps;

CoordsManager(){};
// Default constructor: create the GPU memory manager with its default
// backend. NOTE(review): gpu_memory_manager is declared under
// #ifndef CPU_ONLY — confirm this assignment is excluded or guarded in
// CPU_ONLY builds. (Also drops the stray ';' after the function body.)
CoordsManager() {
  gpu_memory_manager = std::make_shared<GPUMemoryManager>();
}
CoordsManager(int num_threads) {
omp_set_dynamic(0);
omp_set_num_threads(num_threads);
}
// Construct with an explicit OpenMP thread count and an explicit GPU
// memory manager backend (CUDA or PYTORCH).
// NOTE(review): gpu_memory_manager is declared under #ifndef CPU_ONLY —
// confirm this constructor compiles (or is guarded) in CPU_ONLY builds.
CoordsManager(int num_threads, MemoryManagerBackend backend) {
  omp_set_dynamic(0);                 // disable dynamic thread adjustment
  omp_set_num_threads(num_threads);   // fix the OpenMP pool size
  gpu_memory_manager = std::make_shared<GPUMemoryManager>(backend);
}
// Destructor delegates to clear() to release the manager's stored state.
~CoordsManager() { clear(); }

void printDiagnostics(py::object py_coords_key) const;
Expand Down Expand Up @@ -263,7 +270,7 @@ template <typename MapType = CoordsToIndexMap> class CoordsManager {

#ifndef CPU_ONLY
// GPU memory manager
GPUMemoryManager gpu_memory_manager;
std::shared_ptr<GPUMemoryManager> gpu_memory_manager;

// Keep all in out maps throughout the lifecycle of the coords manager
//
Expand Down Expand Up @@ -294,10 +301,10 @@ template <typename MapType = CoordsToIndexMap> class CoordsManager {
py::object py_out_coords_key);

// Return `size` bytes of temporary GPU scratch memory from the manager.
// Idiom fix: use shared_ptr::operator-> instead of the redundant
// .get()-> chain.
void *getScratchGPUMemory(size_t size) {
  return gpu_memory_manager->tmp_data(size);
}

// Release the temporary GPU scratch buffer (operator-> instead of .get()->).
void clearScratchGPUMemory() { gpu_memory_manager->clear_tmp(); }

#endif // CPU_ONLY
}; // coordsmanager
Expand Down
Loading

0 comments on commit 1ddf529

Please sign in to comment.