diff --git a/MinkowskiEngine/utils/quantization.py b/MinkowskiEngine/utils/quantization.py index bfb6bf5c..2286b076 100644 --- a/MinkowskiEngine/utils/quantization.py +++ b/MinkowskiEngine/utils/quantization.py @@ -63,6 +63,30 @@ def ravel_hash_vec(arr): def quantize(coords): + r"""Returns a unique index map and an inverse index map. + + Args: + :attr:`coords` (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a + matrix of size :math:`N \times D` where :math:`N` is the number of + points in the :math:`D` dimensional space. + + Returns: + :attr:`unique_map` (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a + list of indices that defines unique coordinates. + :attr:`coords[unique_map]` is the unique coordinates. + + :attr:`inverse_map` (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a + list of indices that defines the inverse map that recovers the original + coordinates. :attr:`coords[unique_map[inverse_map]] == coords` + + Example:: + + >>> unique_map, inverse_map = quantize(coords) + >>> unique_coords = coords[unique_map] + >>> print(unique_coords[inverse_map] == coords) # True, ..., True + >>> print(coords[unique_map[inverse_map]] == coords) # True, ..., True + + """ assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \ "Invalid coords type" if isinstance(coords, np.ndarray): @@ -92,6 +116,7 @@ def sparse_quantize(coords, labels=None, ignore_label=-100, return_index=False, + return_inverse=False, quantization_size=None): r"""Given coordinates, and features (optionally labels), the function generates quantized (voxelized) coordinates. @@ -117,8 +142,18 @@ def sparse_quantize(coords, IGNORE LABEL. :attr:`torch.nn.CrossEntropyLoss(ignore_index=ignore_label)` - :attr:`return_index` (:attr:`bool`, optional): True if you want the indices of the - quantized coordinates. False by default. + :attr:`return_index` (:attr:`bool`, optional): set True if you want the + indices of the quantized coordinates. False by default. 
+ + :attr:`return_inverse` (:attr:`bool`, optional): set True if you want + the indices that can recover the discretized original coordinates. + False by default. `return_index` must be True when `return_inverse` is True. + + Example:: + + >>> unique_map, inverse_map = sparse_quantize(discrete_coords, return_index=True, return_inverse=True) + >>> unique_coords = discrete_coords[unique_map] + >>> print(unique_coords[inverse_map] == discrete_coords) # True :attr:`quantization_size` (:attr:`float`, :attr:`list`, or :attr:`numpy.ndarray`, optional): the length of the each side of the @@ -150,6 +185,9 @@ def sparse_quantize(coords, assert coords.ndim == 2, \ "The coordinates must be a 2D matrix. The shape of the input is " + str(coords.shape) + if return_inverse: + assert return_index, "return_inverse must be set with return_index" + if use_feat: assert feats.ndim == 2 assert coords.shape[0] == feats.shape[0] @@ -202,24 +240,14 @@ def sparse_quantize(coords, return discrete_coords[mapping], colabels else: - mapping = quantize(discrete_coords) - if len(mapping) > 0: - if return_index: - return mapping + unique_map, inverse_map = quantize(discrete_coords) + if return_index: + if return_inverse: + return unique_map, inverse_map else: - if use_feat: - return discrete_coords[mapping], feats[mapping] - else: - return discrete_coords[mapping] - + return unique_map else: - if return_index: - if isinstance(discrete_coords, np.ndarray): - return np.arange(len(discrete_coords)) - else: - return torch.arange(len(discrete_coords), dtype=torch.long) + if use_feat: + return discrete_coords[unique_map], feats[unique_map] else: - if use_feat: - return discrete_coords, feats - else: - return discrete_coords + return discrete_coords[unique_map] diff --git a/pybind/extern.hpp b/pybind/extern.hpp index c4c5699b..31a44add 100644 --- a/pybind/extern.hpp +++ b/pybind/extern.hpp @@ -380,7 +380,7 @@ UnionBackwardGPU(at::Tensor grad_out_feat, vector py_in_coords_keys, 
/************************************* * Quantization *************************************/ -vector +vector quantize_np(py::array_t coords); vector quantize_label_np( @@ -388,7 +388,7 @@ vector quantize_label_np( py::array_t labels, int invalid_label); -at::Tensor quantize_th(at::Tensor coords); +vector quantize_th(at::Tensor coords); vector quantize_label_th(at::Tensor coords, at::Tensor labels, int invalid_label); diff --git a/src/coordsmap.cpp b/src/coordsmap.cpp index 2013cb2b..5efa6a0d 100644 --- a/src/coordsmap.cpp +++ b/src/coordsmap.cpp @@ -104,6 +104,47 @@ CoordsMap::initialize_batch(const int *p_coords, const int nrows_, return make_pair(mapping, batch_indices); } +// index, inverse_index, batch_indices = initialize_batch_with_inverse(coords) +// unique_coords = coords[index] +// coords == unique_coords[inverse_index] +// coords == coords[index[inverse_index]] +tuple, vector, set> +CoordsMap::initialize_batch_with_inverse(const int *p_coords, const int nrows_, + const int ncols_) { + nrows = nrows_; + ncols = ncols_; + + vector mapping, inverse_mapping; + set batch_indices; + + mapping.reserve(nrows); + inverse_mapping.reserve(nrows); + + int c = 0; + for (int i = 0; i < nrows; i++) { + vector coord(ncols); + std::copy_n(p_coords + i * ncols, ncols, coord.data()); + + auto iter = map.find(coord); + if (iter == map.end()) { + mapping.push_back(i); + inverse_mapping.push_back(c); + +#ifdef BATCH_FIRST + batch_indices.insert(coord[0]); +#else + batch_indices.insert(coord[ncols - 1]); +#endif + map[move(coord)] = c++; + } else { + inverse_mapping.push_back(iter->second); + } + } + + return std::make_tuple(move(mapping), move(inverse_mapping), + move(batch_indices)); +} + CoordsMap CoordsMap::stride(const vector &tensor_strides) const { ASSERT(tensor_strides.size() == ncols - 1, "Invalid tensor strides"); @@ -175,7 +216,7 @@ CoordsMap::union_coords(const vector> &maps) { const auto max_index = std::distance(maps.begin(), max_iter); // Initialize with the largest coords map. 
- const CoordsMap& max_map = maps[max_index]; + const CoordsMap &max_map = maps[max_index]; CoordsMap out_map(max_map); out_map.reserve(num_tot); size_t c = max_map.size(); diff --git a/src/coordsmap.hpp b/src/coordsmap.hpp index f3c112d6..d24abd93 100644 --- a/src/coordsmap.hpp +++ b/src/coordsmap.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include "3rdparty/robin_hood.h" @@ -38,6 +39,7 @@ namespace minkowski { using std::reference_wrapper; using std::set; +using std::tuple; using std::vector; template struct byte_hash_vec { @@ -90,6 +92,10 @@ class CoordsMap { const int ncols_, const bool force_remap = false); + tuple, vector, set> + initialize_batch_with_inverse(const int *p_coords_, const int nrows_, + const int ncols_); + // Generate strided version of the input coordinate map. // returns mapping: out_coord row index to in_coord row index CoordsMap stride(const vector &tensor_strides) const; diff --git a/src/quantization.cpp b/src/quantization.cpp index 269aef2c..bf90848d 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -45,7 +45,7 @@ struct IndexLabel { using CoordsLabelMap = robin_hood::unordered_flat_map, IndexLabel, byte_hash_vec>; -vector quantize_np( +vector quantize_np( py::array_t coords) { py::buffer_info coords_info = coords.request(); auto &shape = coords_info.shape; @@ -58,13 +58,29 @@ vector quantize_np( // Create coords map CoordsMap map; - vector mapping = map.initialize(p_coords, nrows, ncols, false); + auto results = map.initialize_batch_with_inverse(p_coords, nrows, ncols); + auto &mapping = std::get<0>(results); + auto &inverse_mapping = std::get<1>(results); + + // Allocate the output py::array_t buffers that receive the two index maps + py::array_t py_mapping = py::array_t(mapping.size()); + py::array_t py_inverse_mapping = + py::array_t(inverse_mapping.size()); + + py::buffer_info py_mapping_info = py_mapping.request(); + py::buffer_info py_inverse_mapping_info = py_inverse_mapping.request(); + int *p_py_mapping = (int 
*)py_mapping_info.ptr; + int *p_py_inverse_mapping = (int *)py_inverse_mapping_info.ptr; + + std::copy_n(mapping.data(), mapping.size(), p_py_mapping); + std::copy_n(inverse_mapping.data(), inverse_mapping.size(), + p_py_inverse_mapping); // mapping is empty when coords are all unique - return mapping; + return {py_mapping, py_inverse_mapping}; } -at::Tensor quantize_th(at::Tensor coords) { +vector quantize_th(at::Tensor coords) { ASSERT(coords.dtype() == torch::kInt32, "Coordinates must be an int type tensor."); ASSERT(coords.dim() == 2, @@ -72,20 +88,28 @@ at::Tensor quantize_th(at::Tensor coords) { coords.dim(), "!= 2."); CoordsMap map; - vector mapping = - map.initialize(coords.data(), coords.size(0), coords.size(1), false); + auto results = map.initialize_batch_with_inverse( + coords.data(), coords.size(0), coords.size(1)); + auto &mapping = std::get<0>(results); + auto &inverse_mapping = std::get<1>(results); // Long tensor for for easier indexing auto th_mapping = torch::empty({(long)mapping.size()}, torch::TensorOptions().dtype(torch::kInt64)); + auto th_inverse_mapping = + torch::empty({(long)inverse_mapping.size()}, + torch::TensorOptions().dtype(torch::kInt64)); auto a_th_mapping = th_mapping.accessor(); + auto a_th_inverse_mapping = th_inverse_mapping.accessor(); // Copy the output for (size_t i = 0; i < mapping.size(); ++i) a_th_mapping[i] = mapping[i]; + for (size_t i = 0; i < inverse_mapping.size(); ++i) + a_th_inverse_mapping[i] = inverse_mapping[i]; // mapping is empty when coords are all unique - return th_mapping; + return {th_mapping, th_inverse_mapping}; } vector quantize_label_np( diff --git a/tests/quantization.py b/tests/quantization.py index ee48fcb2..2afc4bdd 100644 --- a/tests/quantization.py +++ b/tests/quantization.py @@ -52,11 +52,17 @@ def test(self): def test_mapping(self): N = 16575 coords = (np.random.rand(N, 3) * 100).astype(np.int32) - mapping = MEB.quantize_np(coords) + mapping, inverse_mapping = MEB.quantize_np(coords) print('N 
unique:', len(mapping), 'N:', N) + self.assertTrue((coords == coords[mapping[inverse_mapping]]).all()) - mapping = MEB.quantize_th(torch.from_numpy(coords)) + coords = torch.from_numpy(coords) + mapping, inverse_mapping = MEB.quantize_th(coords) print('N unique:', len(mapping), 'N:', N) + self.assertTrue((coords == coords[mapping[inverse_mapping]]).all()) + + index, reverse_index = sparse_quantize(coords, return_index=True, return_inverse=True) + self.assertTrue((coords == coords[index[reverse_index]]).all()) def test_label(self): N = 16575