Skip to content

Commit

Permalink
Explicit coordinate tracking and OpenMP (#84)
Browse files Browse the repository at this point in the history
* remove threads

* CoordsMan with coordinates, private var rename

* test kernelmap

* OpenMP kernelmap

* OMP kernel map

* example timer added

* added download print

* indoor example with timing

* version up

* changelog
  • Loading branch information
chrischoy authored Aug 28, 2019
1 parent bb2e869 commit a32433c
Show file tree
Hide file tree
Showing 31 changed files with 658 additions and 721 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
# Change Log


## [0.2.6] - 2019-08-28

Use OpenMP for multi-threaded kernel map generation and minor renaming and explicit coordinate management for future upgrades.

### Changed

- Minor name changes in `CoordsManager`.
- `CoordsManager` saves all coordinates for future updates.
- `CoordsManager` functions `createInOutPerKernel` and `createInOutPerKernelTranspose` now support multi-threaded kernel map generation by default using OpenMP.
- Thus, all manual thread functions such as `createInOutPerKernelInThreads`, `initialize_nthreads` removed.
- Use `export OMP_NUM_THREADS` to control the number of threads.


## [0.2.5a0] - 2019-07-12

### Changed

- Added the `MinkowskiBroadcast` and `MinkowskiBroadcastConcatenation` module.


## [0.2.5] - 2019-07-02

### Changed
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ CXXFLAGS += -MMD -MP
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
-DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=$(EXTENSION_NAME) \
-D_GLIBCXX_USE_CXX11_ABI=$(WITH_ABI)
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
CXXFLAGS += -fopenmp -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
LINKFLAGS += -pthread -fPIC $(WARNINGS) -Wl,-rpath=$(PYTHON_LIB_DIR) -Wl,--no-as-needed -Wl,--sysroot=/
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
Expand Down
20 changes: 2 additions & 18 deletions MinkowskiEngine/MinkowskiCoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@
import MinkowskiEngineBackend as MEB


def initialize_nthreads(num_threads, D):
assert num_threads > 0
getattr(MEB, f'PyCoordsManager{D}int32')(num_threads)


class CoordsKey():

def __init__(self, D):
Expand All @@ -57,15 +52,12 @@ def __repr__(self):

class CoordsManager():

def __init__(self, num_threads=None, D=-1):
def __init__(self, D=-1):
if D < 1:
raise ValueError(f"Invalid dimension {D}")
self.D = D
CPPCoordsManager = getattr(MEB, f'PyCoordsManager{D}int32')
if num_threads:
coords_man = CPPCoordsManager(num_threads)
else:
coords_man = CPPCoordsManager()
coords_man = CPPCoordsManager()
self.CPPCoordsManager = coords_man

def initialize(self, coords, coords_key, enforce_creation=False):
Expand Down Expand Up @@ -156,14 +148,6 @@ def get_mapping_by_tensor_strides(self, in_tensor_strides,
out_key = self.get_coords_key(out_tensor_strides)
return self.get_mapping_by_coords_key(in_key, out_key)

def get_mapping_by_coords_key(self, in_coords_key, out_coords_key):
assert isinstance(in_coords_key, CoordsKey) \
and isinstance(out_coords_key, CoordsKey)
mapping = torch.IntTensor()
self.CPPCoordsManager.getCoordsMapping(
mapping, in_coords_key.CPPCoordsKey, out_coords_key.CPPCoordsKey)
return mapping

def permute_label(self,
label,
max_label,
Expand Down
1 change: 1 addition & 0 deletions MinkowskiEngine/SparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def __power__(self, power):

def __repr__(self):
return self.__class__.__name__ + '(' + os.linesep \
+ ' Coords=' + str(self.C) + os.linesep \
+ ' Feats=' + str(self.F) + os.linesep \
+ ' coords_key=' + str(self.coords_key) + os.linesep \
+ ' tensor_stride=' + str(self.coords_key.getTensorStride()) + os.linesep \
Expand Down
4 changes: 2 additions & 2 deletions MinkowskiEngine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
__version__ = "0.2.5a0"
__version__ = "0.2.6"

import os
import sys
Expand All @@ -37,7 +37,7 @@
from Common import RegionType, convert_to_int_tensor, convert_region_type, \
MinkowskiModuleBase, KernelGenerator

from MinkowskiCoords import CoordsKey, CoordsManager, initialize_nthreads
from MinkowskiCoords import CoordsKey, CoordsManager

from MinkowskiConvolution import MinkowskiConvolutionFunction, MinkowskiConvolution, \
MinkowskiConvolutionTransposeFunction, MinkowskiConvolutionTranspose
Expand Down
33 changes: 33 additions & 0 deletions examples/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,43 @@
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import numpy as np
import time

import torch


class Timer(object):
"""A simple timer."""

def __init__(self):
self.reset()

def reset(self):
self.total_time = 0
self.calls = 0
self.start_time = 0
self.diff = 0
self.averate_time = 0
self.min_time = np.Inf

def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()

def toc(self, average=False):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if self.diff < self.min_time:
self.min_time = self.diff
if average:
return self.average_time
else:
return self.diff


def get_coords(data, batch_index=0):
coords = []
for i, row in enumerate(data):
Expand Down
43 changes: 32 additions & 11 deletions examples/indoor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@
import torch
import MinkowskiEngine as ME
from examples.minkunet import MinkUNet34C
from examples.common import Timer

# Check if the weights and file exist and download
if not os.path.isfile('weights.pth'):
urlretrieve("http://cvgl.stanford.edu/data2/minkowskiengine/weights.pth", 'weights.pth')
print('Downloading weights and a room ply file...')
urlretrieve("http://cvgl.stanford.edu/data2/minkowskiengine/weights.pth",
'weights.pth')
urlretrieve("http://cvgl.stanford.edu/data2/minkowskiengine/1.ply", '1.ply')

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -107,6 +110,16 @@ def load_file(file_name, voxel_size):
return quantized_coords[inds], feats[inds], pcd


def generate_input_sparse_tensor(file_name, voxel_size=0.05):
# Create a batch, this process is done in a data loader during training in parallel.
batch = [load_file(file_name, voxel_size)]
coordinates_, featrues_, pcds = list(zip(*batch))
coordinates, features = ME.utils.sparse_collate(coordinates_, featrues_)

# Normalize features and create a sparse tensor
return ME.SparseTensor(features - 0.5, coords=coordinates).to(device)


if __name__ == '__main__':
config = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand All @@ -115,17 +128,25 @@ def load_file(file_name, voxel_size):
model = MinkUNet34C(3, 20).to(device)
model_dict = torch.load(config.weights)
model.load_state_dict(model_dict)

# Create a batch, this process is done in a data loader during training in parallel.
batch = [load_file(config.file_name, 0.02)]
coordinates_, featrues_, pcds = list(zip(*batch))
coordinates, features = ME.utils.sparse_collate(coordinates_, featrues_)

# Normalize features and create a sparse tensor
sinput = ME.SparseTensor(features - 0.5, coords=coordinates).to(device)
model.eval()

# Measure time
for voxel_size in [0.1, 0.05, 0.02]:
timer = Timer()
sinput = generate_input_sparse_tensor(
config.file_name, voxel_size=voxel_size)

# Feed-forward pass and get the prediction
for i in range(4):
timer.tic()
soutput = model(sinput)
timer.toc()
print(
f'Time to process a room with {voxel_size}m voxel downsampling '
f'containing {len(sinput)} voxels: {timer.min_time}'
)

# Feed-forward pass and get the prediction
soutput = model(sinput)
_, pred = soutput.F.max(1)
pred = pred.cpu().numpy()

Expand All @@ -139,7 +160,7 @@ def load_file(file_name, voxel_size):
pred_pcd.colors = o3d.Vector3dVector(colors / 255)

# Move the original point cloud
pcd = pcds[0]
pcd = o3d.read_point_cloud(config.file_name)
pcd.points = o3d.Vector3dVector(np.array(pcd.points) + np.array([0, 5, 0]))

# Visualize the input point cloud and the prediction
Expand Down
2 changes: 0 additions & 2 deletions pybind/extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,4 @@ std::vector<py::array_t<int>>
SparseVoxelization(py::array_t<uint64_t, py::array::c_style> keys,
py::array_t<int, py::array::c_style> labels,
int ignore_label, bool has_label);

void cuda_thread_exit(void);
#endif
3 changes: 0 additions & 3 deletions pybind/minkowski.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ void instantiate_dim_itype(py::module &m, const std::string &dim,
std::string coords_name = std::string("PyCoordsManager") + dim + itypestr;
py::class_<CoordsManager<D, Itype>>(m, coords_name.c_str())
.def(py::init<>())
.def(py::init<int>())
.def("existsCoordsKey", (bool (CoordsManager<D, Itype>::*)(py::object)) &
CoordsManager<D, Itype>::existsCoordsKey)
.def("getCoordsKey", &CoordsManager<D, Itype>::getCoordsKey)
Expand All @@ -144,7 +143,6 @@ void instantiate_dim_itype(py::module &m, const std::string &dim,
.def("initializeCoords", (uint64_t(CoordsManager<D, Itype>::*)(
at::Tensor, py::object, bool)) &
CoordsManager<D, Itype>::initializeCoords)
.def("getCoordsMapping", &CoordsManager<D, Itype>::getCoordsMapping)
.def("__repr__",
[](const CoordsManager<D, Itype> &a) { return a.toString(); });
}
Expand Down Expand Up @@ -172,7 +170,6 @@ void instantiate_dim(py::module &m, const std::string &dim) {
void bind_native(py::module &m) {
#ifndef CPU_ONLY
m.def("SparseVoxelization", &SparseVoxelization);
m.def("CUDAThreadExit", &cuda_thread_exit);
py::class_<GPUMemoryManager<int32_t> >(m, "MemoryManager")
.def(py::init<>())
.def("size", &GPUMemoryManager<int32_t>::size)
Expand Down
28 changes: 14 additions & 14 deletions src/broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ void BroadcastForwardCPU(at::Tensor in_feat, at::Tensor in_feat_glob,
InOutMapKey map_key = p_coords_manager->getOriginMapHashKeyCheck(
py_in_coords_key, py_out_coords_key);

if (p_coords_manager->in_maps.find(map_key) ==
p_coords_manager->in_maps.end())
if (p_coords_manager->_in_maps.find(map_key) ==
p_coords_manager->_in_maps.end())
throw std::invalid_argument(
Formatter() << "Input Output map not found: "
<< std::to_string(hash_vec<InOutMapKey>(map_key)));
Expand All @@ -56,7 +56,7 @@ void BroadcastForwardCPU(at::Tensor in_feat, at::Tensor in_feat_glob,
BroadcastForwardKernelCPU<Dtype, Itype>(
in_feat.data<Dtype>(), in_feat.size(0), in_feat_glob.data<Dtype>(),
in_feat_glob.size(0), out_feat.data<Dtype>(), in_feat.size(1), op,
p_coords_manager->in_maps[map_key], p_coords_manager->out_maps[map_key]);
p_coords_manager->_in_maps[map_key], p_coords_manager->_out_maps[map_key]);
}

template <uint8_t D, typename Dtype, typename Itype>
Expand All @@ -73,8 +73,8 @@ void BroadcastBackwardCPU(at::Tensor in_feat, at::Tensor grad_in_feat,
InOutMapKey map_key = p_coords_manager->getOriginMapHashKeyCheck(
py_in_coords_key, py_out_coords_key);

if (p_coords_manager->in_maps.find(map_key) ==
p_coords_manager->in_maps.end())
if (p_coords_manager->_in_maps.find(map_key) ==
p_coords_manager->_in_maps.end())
throw std::invalid_argument(
Formatter() << "Input Output map not found: "
<< std::to_string(hash_vec<InOutMapKey>(map_key)));
Expand All @@ -88,7 +88,7 @@ void BroadcastBackwardCPU(at::Tensor in_feat, at::Tensor grad_in_feat,
in_feat.data<Dtype>(), grad_in_feat.data<Dtype>(), in_feat.size(0),
in_feat_glob.data<Dtype>(), grad_in_feat_glob.data<Dtype>(),
in_feat_glob.size(0), grad_out_feat.data<Dtype>(), in_feat.size(1), op,
p_coords_manager->in_maps[map_key], p_coords_manager->out_maps[map_key]);
p_coords_manager->_in_maps[map_key], p_coords_manager->_out_maps[map_key]);
}

#ifndef CPU_ONLY
Expand All @@ -105,8 +105,8 @@ void BroadcastForwardGPU(at::Tensor in_feat, at::Tensor in_feat_glob,
InOutMapKey map_key = p_coords_manager->getOriginMapHashKeyCheck(
py_in_coords_key, py_out_coords_key);

if (p_coords_manager->in_maps.find(map_key) ==
p_coords_manager->in_maps.end())
if (p_coords_manager->_in_maps.find(map_key) ==
p_coords_manager->_in_maps.end())
throw std::invalid_argument(
Formatter() << "Input Output map not found: "
<< std::to_string(hash_vec<InOutMapKey>(map_key)));
Expand All @@ -115,15 +115,15 @@ void BroadcastForwardGPU(at::Tensor in_feat, at::Tensor in_feat_glob,
out_feat.zero_();

Itype *d_scr = p_coords_manager->getScratchGPUMemory(
p_coords_manager->out_maps[map_key][0].size());
p_coords_manager->_out_maps[map_key][0].size());

cusparseHandle_t handle =
THCState_getCurrentSparseHandle(at::globalContext().getTHCState());

BroadcastForwardKernelGPU<Dtype, Itype>(
in_feat.data<Dtype>(), in_feat.size(0), in_feat_glob.data<Dtype>(),
in_feat_glob.size(0), out_feat.data<Dtype>(), in_feat.size(1), op,
p_coords_manager->in_maps[map_key], p_coords_manager->out_maps[map_key],
p_coords_manager->_in_maps[map_key], p_coords_manager->_out_maps[map_key],
d_scr, handle, at::cuda::getCurrentCUDAStream());
}

Expand All @@ -142,8 +142,8 @@ void BroadcastBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
InOutMapKey map_key = p_coords_manager->getOriginMapHashKeyCheck(
py_in_coords_key, py_out_coords_key);

if (p_coords_manager->in_maps.find(map_key) ==
p_coords_manager->in_maps.end())
if (p_coords_manager->_in_maps.find(map_key) ==
p_coords_manager->_in_maps.end())
throw std::invalid_argument(
Formatter() << "Input Output map not found: "
<< std::to_string(hash_vec<InOutMapKey>(map_key)));
Expand All @@ -154,7 +154,7 @@ void BroadcastBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
grad_in_feat_glob.zero_();

Itype *d_scr = p_coords_manager->getScratchGPUMemory(
2 * p_coords_manager->out_maps[map_key][0].size() + // in_map + out_map
2 * p_coords_manager->_out_maps[map_key][0].size() + // in_map + out_map
in_feat_glob.size(0) + 1 // d_csr_row
);

Expand All @@ -171,7 +171,7 @@ void BroadcastBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
in_feat.data<Dtype>(), grad_in_feat.data<Dtype>(), in_feat.size(0),
in_feat_glob.data<Dtype>(), grad_in_feat_glob.data<Dtype>(),
in_feat_glob.size(0), grad_out_feat.data<Dtype>(), in_feat.size(1), op,
p_coords_manager->in_maps[map_key], p_coords_manager->out_maps[map_key],
p_coords_manager->_in_maps[map_key], p_coords_manager->_out_maps[map_key],
d_scr, d_dscr, handle, at::cuda::getCurrentCUDAStream());

// p_coords_manager->gpu_memory_manager.reset();
Expand Down
1 change: 0 additions & 1 deletion src/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

#include "coords_manager.hpp"
#include "instantiation.hpp"
#include "thread_pool.hpp"
#include "types.hpp"
#include "utils.hpp"

Expand Down
Loading

0 comments on commit a32433c

Please sign in to comment.