Skip to content

Commit

Permalink
ConcurrentCoordsMap quantization
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit e45a3dd27042710d13511333d7ac2fdbecc80d72
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Wed Dec 11 13:03:56 2019 -0800

    quantization with label

commit 3da9f51d6d3ecddfe9864e06e74d3640f8660663
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Wed Dec 11 01:04:44 2019 -0800

    quantization
  • Loading branch information
chrischoy committed Dec 11, 2019
1 parent c14862e commit 81aacbb
Show file tree
Hide file tree
Showing 13 changed files with 370 additions and 349 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Change Log

## [nightly] - 2019-12-10
## [nightly] - 2019-12-11

### Changed

- Cache in-out mapping on device
- Latest TBB installation instruction update and GCC requirement
- ConcurrentCoordsMap based quantization with label collision


## [0.3.0] - 2019-12-08
Expand Down
2 changes: 1 addition & 1 deletion MinkowskiEngine/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
from .voxelization import sparse_quantize, ravel_hash_vec, fnv_hash_vec
from .quantization import sparse_quantize, ravel_hash_vec, fnv_hash_vec
from .collation import SparseCollation, batched_coordinates, sparse_collate
from .gradcheck import gradcheck
from .init import kaiming_normal_
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch
import numpy as np
from collections import Sequence
import MinkowskiEngineBackend as MEB


def fnv_hash_vec(arr):
Expand Down Expand Up @@ -65,9 +66,7 @@ def sparse_quantize(coords,
feats=None,
labels=None,
ignore_label=255,
set_ignore_label_when_collision=False,
return_index=False,
hash_type='fnv',
quantization_size=1):
r"""Given coordinates, and features (optionally labels), the function
generates quantized (voxelized) coordinates.
Expand All @@ -85,15 +84,9 @@ def sparse_quantize(coords,
ignore_label (:attr:`int`, optional): the int value of the IGNORE LABEL.
set_ignore_label_when_collision (:attr:`bool`, optional): use the `ignore_label`
when at least two points fall into the same cell.
return_index (:attr:`bool`, optional): True if you want the indices of the
quantized coordinates. False by default.
hash_type (:attr:`str`, optional): Hash function used for quantization. Either
`ravel` or `fnv`. `ravel` by default.
quantization_size (:attr:`float`, :attr:`list`, or
:attr:`numpy.ndarray`, optional): the length of the each side of the
hyperrectangle of of the grid cell.
Expand All @@ -104,17 +97,18 @@ def sparse_quantize(coords,
"""
use_label = labels is not None
use_feat = feats is not None

# If only coordindates are given, return the index
if not use_label and not use_feat:
return_index = True

assert hash_type in [
'ravel', 'fnv'
], "Invalid hash_type. Either ravel, or fnv allowed. You put hash_type=" + hash_type
assert coords.ndim == 2, \
"The coordinates must be a 2D matrix. The shape of the input is " + str(coords.shape)

if use_feat:
assert feats.ndim == 2
assert coords.shape[0] == feats.shape[0]

if use_label:
assert coords.shape[0] == len(labels)

Expand All @@ -124,34 +118,38 @@ def sparse_quantize(coords,
assert len(
quantization_size
) == dimension, "Quantization size and coordinates size mismatch."
quantization_size = [i for i in quantization_size]
quantization_size = np.array([i for i in quantization_size])
discrete_coords = np.floor(coords / quantization_size)
elif np.isscalar(quantization_size): # Assume that it is a scalar
quantization_size = [quantization_size for i in range(dimension)]
else:
raise ValueError('Not supported type for quantization_size.')
discrete_coords = np.floor(coords / np.array(quantization_size))

# Hash function type
if hash_type == 'ravel':
key = ravel_hash_vec(discrete_coords)
if quantization_size == 1:
discrete_coords = coords
else:
quantization_size = np.array(
[quantization_size for i in range(dimension)])
discrete_coords = np.floor(coords / quantization_size)
else:
key = fnv_hash_vec(discrete_coords)
raise ValueError('Not supported type for quantization_size.')

# Return values accordingly
if use_label:
_, inds, counts = np.unique(key, return_index=True, return_counts=True)
filtered_labels = labels[inds]
if set_ignore_label_when_collision:
filtered_labels[counts > 1] = ignore_label
mapping, colabels = MEB.quantize_label(discrete_coords, labels,
ignore_label)

if return_index:
return inds, filtered_labels
return mapping, colabels
else:
return discrete_coords[inds], feats[inds], filtered_labels
if use_feat:
return discrete_coords[mapping], feats[mapping], colabels
else:
return discrete_coords[mapping], colabels

else:
_, inds = np.unique(key, return_index=True)
mapping = MEB.quantize(discrete_coords)
if return_index:
return inds
return mapping
else:
if use_feat:
return discrete_coords[inds], feats[inds]
return discrete_coords[mapping], feats[mapping]
else:
return discrete_coords[inds]
return discrete_coords[mapping]
57 changes: 31 additions & 26 deletions pybind/extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
* of the code.
*/
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>

Expand Down Expand Up @@ -255,31 +256,35 @@ void GlobalPoolingBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
* GlobalMaxPooling
*************************************/
template <typename Dtype>
void GlobalMaxPoolingForwardCPU(
at::Tensor in_feat, at::Tensor out_feat, at::Tensor num_nonzero,
py::object py_in_coords_key, py::object py_out_coords_key,
py::object py_coords_manager);
void GlobalMaxPoolingForwardCPU(at::Tensor in_feat, at::Tensor out_feat,
at::Tensor num_nonzero,
py::object py_in_coords_key,
py::object py_out_coords_key,
py::object py_coords_manager);

template <typename Dtype>
void GlobalMaxPoolingBackwardCPU(
at::Tensor in_feat, at::Tensor grad_in_feat,
at::Tensor grad_out_feat, at::Tensor num_nonzero,
py::object py_in_coords_key, py::object py_out_coords_key,
py::object py_coords_manager);
void GlobalMaxPoolingBackwardCPU(at::Tensor in_feat, at::Tensor grad_in_feat,
at::Tensor grad_out_feat,
at::Tensor num_nonzero,
py::object py_in_coords_key,
py::object py_out_coords_key,
py::object py_coords_manager);

#ifndef CPU_ONLY
template <typename Dtype>
void GlobalMaxPoolingForwardGPU(
at::Tensor in_feat, at::Tensor out_feat, at::Tensor num_nonzero,
py::object py_in_coords_key, py::object py_out_coords_key,
py::object py_coords_manager);
void GlobalMaxPoolingForwardGPU(at::Tensor in_feat, at::Tensor out_feat,
at::Tensor num_nonzero,
py::object py_in_coords_key,
py::object py_out_coords_key,
py::object py_coords_manager);

template <typename Dtype>
void GlobalMaxPoolingBackwardGPU(
at::Tensor in_feat, at::Tensor grad_in_feat,
at::Tensor grad_out_feat, at::Tensor num_nonzero,
py::object py_in_coords_key, py::object py_out_coords_key,
py::object py_coords_manager);
void GlobalMaxPoolingBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
at::Tensor grad_out_feat,
at::Tensor num_nonzero,
py::object py_in_coords_key,
py::object py_out_coords_key,
py::object py_coords_manager);
#endif

/*************************************
Expand Down Expand Up @@ -347,12 +352,12 @@ void PruningBackwardGPU(at::Tensor grad_in_feat, at::Tensor grad_out_feat,
#endif

/*************************************
* Voxelization
* Quantization
*************************************/
#ifndef CPU_ONLY
#include <pybind11/numpy.h>
std::vector<py::array_t<int>>
SparseVoxelization(py::array_t<uint64_t, py::array::c_style> keys,
py::array_t<int, py::array::c_style> labels,
int ignore_label, bool has_label);
#endif
py::array
quantize(py::array_t<int, py::array::c_style | py::array::forcecast> coords);

vector<py::array> quantize_label(
py::array_t<int, py::array::c_style | py::array::forcecast> coords,
py::array_t<int, py::array::c_style | py::array::forcecast> labels,
int invalid_label);
5 changes: 4 additions & 1 deletion pybind/minkowski.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ void instantiate(py::module &m) {

void bind_native(py::module &m) {
#ifndef CPU_ONLY
m.def("SparseVoxelization", &SparseVoxelization);
py::class_<GPUMemoryManager>(m, "MemoryManager")
.def(py::init<>())
.def("resize", &GPUMemoryManager::resize);
Expand All @@ -186,6 +185,10 @@ void bind_native(py::module &m) {
.def("setTensorStride", &CoordsKey::setTensorStride)
.def("getTensorStride", &CoordsKey::getTensorStride)
.def("__repr__", [](const CoordsKey &a) { return a.toString(); });

// Quantization
m.def("quantize", &quantize);
m.def("quantize_label", &quantize_label);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand Down
62 changes: 24 additions & 38 deletions src/concurrent_coordsmap.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
/* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*
* Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
* Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
* of the code.
*/
#include "concurrent_coordsmap.hpp"

void ConcurrentCoordsMap::set_threads(int num_threads_) {
Expand Down Expand Up @@ -81,13 +105,6 @@ ConcurrentCoordsMap::initialize(vector<int> &&coords_, int nrows_, int ncols_,
for (const auto &i : concurr_batch_indices)
batch_indices.insert(i.first);

/* for (const auto &i : map) {
cout << "Init: " << &coords[ncols * i.second] << " " << i.first.ptr << " ";
for (int k = 0; k < i.first.size; k++)
cout << i.first.ptr[k] << " ";
cout << ":" << i.second << endl;
} */

// cout << "Creation: " << (t1 - t0).seconds() << endl;
int unique_size = map.size();

Expand Down Expand Up @@ -240,13 +257,6 @@ ConcurrentCoordsMap ConcurrentCoordsMap::stride_region(const Region &region) {
},
tbb::auto_partitioner());

/* for (const auto &i : tmp_map) {
cout << "Strided : " << i.first.ptr << " ";
for (int k = 0; k < i.first.size; k++)
cout << i.first.ptr[k] << " ";
cout << ":" << i.second << endl;
} */

// p_mapping
tbb::concurrent_vector<int *> p_mapping;
p_mapping.reserve(tmp_map.size());
Expand Down Expand Up @@ -286,14 +296,6 @@ ConcurrentCoordsMap ConcurrentCoordsMap::stride_region(const Region &region) {
},
tbb::auto_partitioner());

/* for (const auto &i : out_inner_map) {
cout << "Coords Strided : " << &out_coords[i.second * ncols] << " "
<< i.first.ptr << " ";
for (int k = 0; k < i.first.size; k++)
cout << i.first.ptr[k] << " " << out_coords[i.second * ncols + k] << " ";
cout << ":" << i.second << endl;
} */

ASSERT(out_coords_map.size() == out_coords_map.nrows, "Map size mismatch");

return out_coords_map;
Expand Down Expand Up @@ -355,14 +357,6 @@ void ConcurrentCoordsMap::updateUniqueCoords(
ConcurrentCoordsInnerMap new_map;
vector<int> new_coords(map.size() * ncols);

/* for (const auto &i : map) {
cout << "Before unique: " << &coords[ncols * i.second] << " " << i.first.ptr
<< " ";
for (int k = 0; k < i.first.size; k++)
cout << i.first.ptr[k] << " ";
cout << ":" << i.second << endl;
} */

tbb::parallel_for(map.range(), [&](decltype(map)::const_range_type &r) {
Coord<int> coord;
coord.size = ncols;
Expand All @@ -375,14 +369,6 @@ void ConcurrentCoordsMap::updateUniqueCoords(
}
});

/* for (const auto &i : new_map) {
cout << "After unique: " << &out_coords[ncols * i.second] << " "
<< i.first.ptr << " ";
for (int k = 0; k < i.first.size; k++)
cout << i.first.ptr[k] << " ";
cout << ":" << i.second << endl;
} */

ASSERT(map.size() == new_map.size(), "Remapping sizes different.")

// Must move to replace the map explicitly.
Expand Down
29 changes: 29 additions & 0 deletions src/concurrent_coordsmap.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,30 @@
/* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*
* Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
* Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
* of the code.
*/
#ifndef CONCURRENT_COORDSMAP
#define CONCURRENT_COORDSMAP

#include <tbb/blocked_range.h>
#include <tbb/concurrent_unordered_map.h>
#include <tbb/concurrent_vector.h>
Expand Down Expand Up @@ -135,6 +162,7 @@ class ConcurrentCoordsMap {
const vector<int> &tensor_strides) const;

// Iterators
ConcurrentCoordsInnerMap::const_range_type range() { return map.range(); }
ConcurrentCoordsInnerMap::iterator begin() { return map.begin(); }
ConcurrentCoordsInnerMap::const_iterator begin() const { return map.begin(); }
ConcurrentCoordsInnerMap::iterator end() { return map.end(); }
Expand All @@ -157,3 +185,4 @@ class ConcurrentCoordsMap {

void print() const;
};
#endif // CONCURRENT_COORDSMAP
Loading

0 comments on commit 81aacbb

Please sign in to comment.