Skip to content

Commit

Permalink
inverse quantization map (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Apr 2, 2020
1 parent 573b16c commit 9d8dfae
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 32 deletions.
68 changes: 48 additions & 20 deletions MinkowskiEngine/utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,30 @@ def ravel_hash_vec(arr):


def quantize(coords):
r"""Returns a unique index map and an inverse index map.
Args:
:attr:`coords` (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a
matrix of size :math:`N \times D` where :math:`N` is the number of
points in the :math:`D` dimensional space.
Returns:
:attr:`unique_map` (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a
list of indices that defines unique coordinates.
:attr:`coords[unique_map]` is the unique coordinates.
:attr:`inverse_map` (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a
list of indices that defines the inverse map that recovers the original
coordinates. :attr:`coords[unique_map[inverse_map]] == coords`
Example::
>>> unique_map, inverse_map = quantize(coords)
>>> unique_coords = coords[unique_map]
>>> print(unique_coords[inverse_map] == coords) # True, ..., True
>>> print(coords[unique_map[inverse_map]] == coords) # True, ..., True
"""
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
"Invalid coords type"
if isinstance(coords, np.ndarray):
Expand Down Expand Up @@ -92,6 +116,7 @@ def sparse_quantize(coords,
labels=None,
ignore_label=-100,
return_index=False,
return_inverse=False,
quantization_size=None):
r"""Given coordinates, and features (optionally labels), the function
generates quantized (voxelized) coordinates.
Expand All @@ -117,8 +142,18 @@ def sparse_quantize(coords,
IGNORE LABEL.
:attr:`torch.nn.CrossEntropyLoss(ignore_index=ignore_label)`
:attr:`return_index` (:attr:`bool`, optional): True if you want the indices of the
quantized coordinates. False by default.
:attr:`return_index` (:attr:`bool`, optional): set True if you want the
indices of the quantized coordinates. False by default.
:attr:`return_inverse` (:attr:`bool`, optional): set True if you want
the indices that can recover the discretized original coordinates.
False by default. `return_index` must be True when `return_inverse` is True.
Example::
>>> unique_map, inverse_map = sparse_quantize(discrete_coords, return_index=True, return_inverse=True)
>>> unique_coords = discrete_coords[unique_map]
>>> print(unique_coords[inverse_map] == discrete_coords) # True
:attr:`quantization_size` (:attr:`float`, :attr:`list`, or
:attr:`numpy.ndarray`, optional): the length of the each side of the
Expand Down Expand Up @@ -150,6 +185,9 @@ def sparse_quantize(coords,
assert coords.ndim == 2, \
"The coordinates must be a 2D matrix. The shape of the input is " + str(coords.shape)

if return_inverse:
assert return_index, "return_reverse must be set with return_index"

if use_feat:
assert feats.ndim == 2
assert coords.shape[0] == feats.shape[0]
Expand Down Expand Up @@ -202,24 +240,14 @@ def sparse_quantize(coords,
return discrete_coords[mapping], colabels

else:
mapping = quantize(discrete_coords)
if len(mapping) > 0:
if return_index:
return mapping
unique_map, inverse_map = quantize(discrete_coords)
if return_index:
if return_inverse:
return unique_map, inverse_map
else:
if use_feat:
return discrete_coords[mapping], feats[mapping]
else:
return discrete_coords[mapping]

return unique_map
else:
if return_index:
if isinstance(discrete_coords, np.ndarray):
return np.arange(len(discrete_coords))
else:
return torch.arange(len(discrete_coords), dtype=torch.long)
if use_feat:
return discrete_coords[unique_map], feats[unique_map]
else:
if use_feat:
return discrete_coords, feats
else:
return discrete_coords
return discrete_coords[unique_map]
4 changes: 2 additions & 2 deletions pybind/extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,15 @@ UnionBackwardGPU(at::Tensor grad_out_feat, vector<py::object> py_in_coords_keys,
/*************************************
* Quantization
*************************************/
vector<int>
vector<py::array>
quantize_np(py::array_t<int, py::array::c_style | py::array::forcecast> coords);

vector<py::array> quantize_label_np(
py::array_t<int, py::array::c_style | py::array::forcecast> coords,
py::array_t<int, py::array::c_style | py::array::forcecast> labels,
int invalid_label);

at::Tensor quantize_th(at::Tensor coords);
vector<at::Tensor> quantize_th(at::Tensor coords);

vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
int invalid_label);
Expand Down
43 changes: 42 additions & 1 deletion src/coordsmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,47 @@ CoordsMap::initialize_batch(const int *p_coords, const int nrows_,
return make_pair(mapping, batch_indices);
}

// index, inverse_index, batch_indices = initialize_batch_with_inverse(coords)
//
// Reads `nrows_` rows of `ncols_` ints from `p_coords` (row-major), inserts
// every previously unseen row into `map`, and returns:
//   index         - row number of the first occurrence of each new coordinate
//   inverse_index - for every input row, the value stored in `map` for it
//   batch_indices - batch component of every newly inserted coordinate
//
// So that:
//   unique_coords = coords[index]
//   coords == unique_coords[inverse_index]
//   coords == coords[index[inverse_index]]
tuple<vector<int>, vector<int>, set<int>>
CoordsMap::initialize_batch_with_inverse(const int *p_coords, const int nrows_,
                                         const int ncols_) {
  nrows = nrows_;
  ncols = ncols_;

  vector<int> first_rows, row_to_unique;
  set<int> seen_batches;

  first_rows.reserve(nrows);
  row_to_unique.reserve(nrows);

  int next_index = 0;
  for (int row = 0; row < nrows; row++) {
    // Materialize the current row as the hash-map key.
    const int *row_begin = p_coords + row * ncols;
    vector<int> key(row_begin, row_begin + ncols);

    const auto found = map.find(key);
    if (found != map.end()) {
      // Duplicate coordinate: point back at the entry already in the map.
      row_to_unique.push_back(found->second);
      continue;
    }

    // First occurrence of this coordinate.
    first_rows.push_back(row);
    row_to_unique.push_back(next_index);

    // Record the batch index before the key is moved into the map.
#ifdef BATCH_FIRST
    seen_batches.insert(key[0]);
#else
    seen_batches.insert(key[ncols - 1]);
#endif
    map[std::move(key)] = next_index;
    ++next_index;
  }

  return std::make_tuple(std::move(first_rows), std::move(row_to_unique),
                         std::move(seen_batches));
}

CoordsMap CoordsMap::stride(const vector<int> &tensor_strides) const {
ASSERT(tensor_strides.size() == ncols - 1, "Invalid tensor strides");

Expand Down Expand Up @@ -175,7 +216,7 @@ CoordsMap::union_coords(const vector<reference_wrapper<CoordsMap>> &maps) {
const auto max_index = std::distance(maps.begin(), max_iter);

// Initialize with the largest coords map.
const CoordsMap& max_map = maps[max_index];
const CoordsMap &max_map = maps[max_index];
CoordsMap out_map(max_map);
out_map.reserve(num_tot);
size_t c = max_map.size();
Expand Down
6 changes: 6 additions & 0 deletions src/coordsmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cmath>
#include <memory>
#include <set>
#include <tuple>

#include "3rdparty/robin_hood.h"

Expand All @@ -38,6 +39,7 @@ namespace minkowski {

using std::reference_wrapper;
using std::set;
using std::tuple;
using std::vector;

template <typename Itype> struct byte_hash_vec {
Expand Down Expand Up @@ -90,6 +92,10 @@ class CoordsMap {
const int ncols_,
const bool force_remap = false);

tuple<vector<int>, vector<int>, set<int>>
initialize_batch_with_inverse(const int *p_coords_, const int nrows_,
const int ncols_);

// Generate strided version of the input coordinate map.
// returns mapping: out_coord row index to in_coord row index
CoordsMap stride(const vector<int> &tensor_strides) const;
Expand Down
38 changes: 31 additions & 7 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct IndexLabel {
using CoordsLabelMap =
robin_hood::unordered_flat_map<vector<int>, IndexLabel, byte_hash_vec<int>>;

vector<int> quantize_np(
vector<py::array> quantize_np(
py::array_t<int, py::array::c_style | py::array::forcecast> coords) {
py::buffer_info coords_info = coords.request();
auto &shape = coords_info.shape;
Expand All @@ -58,34 +58,58 @@ vector<int> quantize_np(

// Create coords map
CoordsMap map;
vector<int> mapping = map.initialize(p_coords, nrows, ncols, false);
auto results = map.initialize_batch_with_inverse(p_coords, nrows, ncols);
auto &mapping = std::get<0>(results);
auto &inverse_mapping = std::get<1>(results);

// Copy the concurrent vector to std vector
py::array_t<int> py_mapping = py::array_t<int>(mapping.size());
py::array_t<int> py_inverse_mapping =
py::array_t<int>(inverse_mapping.size());

py::buffer_info py_mapping_info = py_mapping.request();
py::buffer_info py_inverse_mapping_info = py_inverse_mapping.request();
int *p_py_mapping = (int *)py_mapping_info.ptr;
int *p_py_inverse_mapping = (int *)py_inverse_mapping_info.ptr;

std::copy_n(mapping.data(), mapping.size(), p_py_mapping);
std::copy_n(inverse_mapping.data(), inverse_mapping.size(),
p_py_inverse_mapping);

// mapping is empty when coords are all unique
return mapping;
return {py_mapping, py_inverse_mapping};
}

at::Tensor quantize_th(at::Tensor coords) {
vector<at::Tensor> quantize_th(at::Tensor coords) {
ASSERT(coords.dtype() == torch::kInt32,
"Coordinates must be an int type tensor.");
ASSERT(coords.dim() == 2,
"Coordinates must be represnted as a matrix. Dimensions: ",
coords.dim(), "!= 2.");

CoordsMap map;
vector<int> mapping =
map.initialize(coords.data<int>(), coords.size(0), coords.size(1), false);
auto results = map.initialize_batch_with_inverse(
coords.data<int>(), coords.size(0), coords.size(1));
auto mapping = std::get<0>(results);
auto inverse_mapping = std::get<1>(results);

// Long tensor for easier indexing
auto th_mapping = torch::empty({(long)mapping.size()},
torch::TensorOptions().dtype(torch::kInt64));
auto th_inverse_mapping =
torch::empty({(long)inverse_mapping.size()},
torch::TensorOptions().dtype(torch::kInt64));
auto a_th_mapping = th_mapping.accessor<long int, 1>();
auto a_th_inverse_mapping = th_inverse_mapping.accessor<long int, 1>();

// Copy the output
for (size_t i = 0; i < mapping.size(); ++i)
a_th_mapping[i] = mapping[i];
for (size_t i = 0; i < inverse_mapping.size(); ++i)
a_th_inverse_mapping[i] = inverse_mapping[i];

// mapping is empty when coords are all unique
return th_mapping;
return {th_mapping, th_inverse_mapping};
}

vector<py::array> quantize_label_np(
Expand Down
10 changes: 8 additions & 2 deletions tests/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ def test(self):
def test_mapping(self):
N = 16575
coords = (np.random.rand(N, 3) * 100).astype(np.int32)
mapping = MEB.quantize_np(coords)
mapping, inverse_mapping = MEB.quantize_np(coords)
print('N unique:', len(mapping), 'N:', N)
self.assertTrue((coords == coords[mapping[inverse_mapping]]).all())

mapping = MEB.quantize_th(torch.from_numpy(coords))
coords = torch.from_numpy(coords)
mapping, inverse_mapping = MEB.quantize_th(coords)
print('N unique:', len(mapping), 'N:', N)
self.assertTrue((coords == coords[mapping[inverse_mapping]]).all())

index, reverse_index = sparse_quantize(coords, return_index=True, return_inverse=True)
self.assertTrue((coords == coords[mapping[inverse_mapping]]).all())

def test_label(self):
N = 16575
Expand Down

0 comments on commit 9d8dfae

Please sign in to comment.