diff --git a/CHANGELOG.md b/CHANGELOG.md
index cbdee36e..74c6f7d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,13 +1,21 @@
 # Change Log
 
-## [0.5.0]
-
-### Added
+## [0.5.1] - v0.5 documentation updates
 
 - Nonlinear functionals and modules
 - Warning when using cuda without ME cuda support
 - diagnostics test
+- TensorField slice
+  - Cache the unique map and inverse map pair in the coordinate manager
+- CoordinateManager
+  - `field_to_sparse_insert_and_map`
+  - `exists_field_to_sparse`
+  - `get_field_to_sparse_map`
+- CoordinateFieldMap
+  - `quantize_coordinates`
+
+## [0.5.0] - 2020-12-24
 
 ## [0.5.0a] - 2020-08-05
 
diff --git a/MinkowskiEngine/MinkowskiCoordinateManager.py b/MinkowskiEngine/MinkowskiCoordinateManager.py
index 95481593..813a66f8 100644
--- a/MinkowskiEngine/MinkowskiCoordinateManager.py
+++ b/MinkowskiEngine/MinkowskiCoordinateManager.py
@@ -162,6 +162,9 @@ def insert_and_map(
         coordinate_map_type == `CoordinateMapType.GPU`) that defines the
         coordinates.
 
+        :attr:`tensor_stride` (`list`): a list of `D` elements that defines the
+        tensor stride for the new order-`D + 1` sparse tensor.
+
         Example::
 
            >>> manager = CoordinateManager(D=1)
@@ -178,7 +181,7 @@ def insert_and_map(
     def insert_field(
         self,
         coordinates: torch.Tensor,
-        tensor_stride: Union[int, Sequence, np.ndarray],
+        tensor_stride: Sequence,
         string_id: str = "",
     ) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]:
         r"""create a new coordinate map and returns
@@ -186,6 +189,10 @@ def insert_field(
         :attr:`coordinates`: `torch.FloatTensor` (`CUDA` if coordinate_map_type
         == `CoordinateMapType.GPU`) that defines the coordinates.
 
+        :attr:`tensor_stride` (`list`): a list of `D` elements that defines the
+        tensor stride for the new order-`D + 1` sparse tensor.
+
+
         Example::
 
            >>> manager = CoordinateManager(D=1)
@@ -198,6 +205,45 @@ def insert_field(
         """
         return self._manager.insert_field(coordinates, tensor_stride, string_id)
 
+    def field_to_sparse_insert_and_map(
+        self,
+        field_map_key: CoordinateMapKey,
+        sparse_tensor_stride: Union[int, Sequence, np.ndarray],
+        sparse_tensor_string_id: str = "",
+    ) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]:
+
+        r"""Create a sparse tensor coordinate map with the tensor stride.
+
+        :attr:`field_map_key` (`CoordinateMapKey`): field map that a new sparse
+        tensor will be created from.
+
+        :attr:`sparse_tensor_stride` (`list`): a list of `D` elements that defines
+        the tensor stride for the new order-`D + 1` sparse tensor.
+
+        :attr:`sparse_tensor_string_id` (`str`): string id of the new sparse tensor coordinate map key.
+
+        Example::
+
+            >>> manager = CoordinateManager(D=1)
+            >>> coordinates = torch.FloatTensor([[0, 0.1], [0, 2.3], [0, 1.2], [0, 2.4]])
+            >>> field_key, _ = manager.insert_field(coordinates, [1])
+            >>> key, (unique_map, inverse_map) = manager.field_to_sparse_insert_and_map(field_key, [1])
+
+        """
+        return self._manager.field_to_sparse_insert_and_map(
+            field_map_key, sparse_tensor_stride, sparse_tensor_string_id
+        )
+
+    def exists_field_to_sparse(
+        self, field_map_key: CoordinateMapKey, sparse_map_key: CoordinateMapKey
+    ):
+        return self._manager.exists_field_to_sparse(field_map_key, sparse_map_key)
+
+    def get_field_to_sparse_map(
+        self, field_map_key: CoordinateMapKey, sparse_map_key: CoordinateMapKey
+    ):
+        return self._manager.get_field_to_sparse_map(field_map_key, sparse_map_key)
+
     def stride(
         self,
         coordinate_map_key: CoordinateMapKey,
@@ -284,6 +330,9 @@ def get_unique_coordinate_map_key(
         """
         Returns a unique coordinate_map_key for a given tensor stride.
+        :attr:`tensor_stride` (`list`): a list of `D` elements that defines the
+        tensor stride for the new order-`D + 1` sparse tensor.
+
         """
         return self._manager.get_random_string_id(tensor_stride, "")
 
diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py
index 4a2433cb..ff4d47e1 100644
--- a/MinkowskiEngine/MinkowskiSparseTensor.py
+++ b/MinkowskiEngine/MinkowskiSparseTensor.py
@@ -564,7 +564,7 @@ def slice(self, X):
 
         if isinstance(X, TensorField):
             return TensorField(
-                self.F[X.inverse_mapping],
+                self.F[X.inverse_mapping(self.coordinate_map_key)],
                 coordinate_field_map_key=X.coordinate_field_map_key,
                 coordinate_manager=X.coordinate_manager,
                 quantization_mode=X.quantization_mode,
@@ -616,7 +616,9 @@ def cat_slice(self, X):
 
         from MinkowskiTensorField import TensorField
 
-        features = torch.cat((self.F[X.inverse_mapping], X.F), dim=1)
+        features = torch.cat(
+            (self.F[X.inverse_mapping(self.coordinate_map_key)], X.F), dim=1
+        )
         if isinstance(X, TensorField):
             return TensorField(
                 features,
@@ -630,7 +632,7 @@ def cat_slice(self, X):
         ), "Slice can only be applied on the same coordinates (coordinate_map_key)"
         return TensorField(
             features,
-            coordinates=self.C[X.inverse_mapping],
+            coordinates=self.C[X.inverse_mapping(self.coordinate_map_key)],
             coordinate_manager=self.coordinate_manager,
             quantization_mode=self.quantization_mode,
         )
diff --git a/MinkowskiEngine/MinkowskiTensorField.py b/MinkowskiEngine/MinkowskiTensorField.py
index d952c11d..a165134c 100644
--- a/MinkowskiEngine/MinkowskiTensorField.py
+++ b/MinkowskiEngine/MinkowskiTensorField.py
@@ -22,8 +22,11 @@
 # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
 # of the code.
 import os
-import torch
+import numpy as np
+from collections.abc import Sequence
+from typing import Union, List, Tuple
 
+import torch
 from MinkowskiCommon import convert_to_int_list, StrideType
 from MinkowskiEngineBackend._C import (
     GPUMemoryAllocatorType,
@@ -41,6 +44,7 @@
     set_global_coordinate_manager,
 )
 from MinkowskiSparseTensor import SparseTensor
+from sparse_matrix_functions import MinkowskiSPMMFunction
 
 
 class TensorField(Tensor):
@@ -212,6 +216,7 @@ def __init__(
         self._C = coordinates
         self.coordinate_field_map_key = coordinate_field_map_key
         self._batch_rows = None
+        self._inverse_mapping = {}
 
     @property
     def C(self):
@@ -243,29 +248,82 @@ def _batchwise_row_indices(self):
     def _get_coordinate_field(self):
         return self._manager.get_coordinate_field(self.coordinate_field_map_key)
 
-    def sparse(self, quantization_mode=None):
+    def sparse(
+        self, tensor_stride: Union[int, Sequence, np.ndarray] = 1, quantization_mode=None
+    ):
         r"""Converts the current sparse tensor field to a sparse tensor."""
         if quantization_mode is None:
             quantization_mode = self.quantization_mode
 
+        tensor_stride = convert_to_int_list(tensor_stride, self.D)
+
+        sparse_tensor_key, (
+            unique_index,
+            inverse_mapping,
+        ) = self._manager.field_to_sparse_insert_and_map(
+            self.coordinate_field_map_key,
+            tensor_stride,
+        )
+
+        self._inverse_mapping[sparse_tensor_key] = inverse_mapping
+
+        if self.quantization_mode in [
+            SparseTensorQuantizationMode.UNWEIGHTED_SUM,
+            SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
+        ]:
+            spmm = MinkowskiSPMMFunction()
+            N = len(self._F)
+            cols = torch.arange(
+                N,
+                dtype=inverse_mapping.dtype,
+                device=inverse_mapping.device,
+            )
+            vals = torch.ones(N, dtype=self._F.dtype, device=self._F.device)
+            size = torch.Size([len(unique_index), len(inverse_mapping)])
+            features = spmm.apply(inverse_mapping, cols, vals, size,
+                                  self._F)
+            if (
+                self.quantization_mode
+                == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
+            ):
+                nums = spmm.apply(
+                    inverse_mapping,
+                    cols,
+                    vals,
+                    size,
+                    vals.reshape(N, 1),
+                )
+                features /= nums
+        elif self.quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE:
+            features = self._F[unique_index]
+        else:
+            # Unsupported quantization mode for field-to-sparse conversion
+            raise ValueError("Invalid quantization mode")
+
         sparse_tensor = SparseTensor(
-            self._F,
-            coordinates=self.coordinates,
-            quantization_mode=quantization_mode,
-            coordinate_manager=self.coordinate_manager,
+            features,
+            coordinate_map_key=sparse_tensor_key,
+            coordinate_manager=self._manager,
         )
-        # Save the inverse mapping
-        self._inverse_mapping = sparse_tensor.inverse_mapping
         return sparse_tensor
 
-    @property
-    def inverse_mapping(self):
-        if not hasattr(self, "_inverse_mapping"):
-            raise ValueError(
-                "Did you run SparseTensor.slice? The slice must take a tensor field that returned TensorField.space."
-            )
-        return self._inverse_mapping
+    def inverse_mapping(self, sparse_tensor_map_key: CoordinateMapKey):
+        if sparse_tensor_map_key not in self._inverse_mapping:
+            if not self._manager.exists_field_to_sparse(
+                self.coordinate_field_map_key, sparse_tensor_map_key
+            ):
+                raise ValueError(
+                    f"The field to sparse tensor mapping does not exist for the key: {sparse_tensor_map_key}. Please run TensorField.sparse({sparse_tensor_map_key.get_tensor_stride()}) first."
+                )
+            else:
+                # Extract the mapping
+                (
+                    _,
+                    self._inverse_mapping[sparse_tensor_map_key],
+                ) = self._manager.get_field_to_sparse_map(
+                    self.coordinate_field_map_key, sparse_tensor_map_key
+                )
+        return self._inverse_mapping[sparse_tensor_map_key]
 
     def __repr__(self):
         return (
diff --git a/examples/resnet.py b/examples/resnet.py
index ff40aa9c..bbe9a84e 100644
--- a/examples/resnet.py
+++ b/examples/resnet.py
@@ -124,7 +124,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0
 
         return nn.Sequential(*layers)
 
-    def forward(self, x):
+    def forward(self, x: ME.SparseTensor):
         x = self.conv1(x)
         x = self.layer1(x)
         x = self.layer2(x)
@@ -185,7 +185,7 @@ def network_initialization(self, in_channels, out_channels, D):
 
         ResNetBase.network_initialization(self, field_ch2, out_channels, D)
 
-    def forward(self, x):
+    def forward(self, x: ME.TensorField):
         otensor = self.field_network(x)
         otensor2 = self.field_network2(otensor.cat_slice(x))
         return ResNetBase.forward(self, otensor2)
diff --git a/pybind/extern.hpp b/pybind/extern.hpp
index 04357556..8e211b6c 100644
--- a/pybind/extern.hpp
+++ b/pybind/extern.hpp
@@ -755,6 +755,13 @@ void instantiate_manager(py::module &m, const std::string &dtypestr) {
            &manager_type::to_string, py::const_))
       .def("insert_and_map", &manager_type::insert_and_map)
       .def("insert_field", &manager_type::insert_field)
+      .def("field_to_sparse_insert_and_map",
+           &manager_type::field_to_sparse_insert_and_map)
+      .def("exists_field_to_sparse",
+           py::overload_cast<CoordinateMapKey const *, CoordinateMapKey const *>(
+               &manager_type::exists_field_to_sparse, py::const_))
+      .def("get_field_to_sparse_map", &manager_type::get_field_to_sparse_map)
       .def("stride", &manager_type::py_stride)
       .def("origin", &manager_type::py_origin)
       .def("get_coordinates", &manager_type::get_coordinates)
diff --git a/src/coordinate_map_cpu.hpp b/src/coordinate_map_cpu.hpp
index 0266e848..0e77948c 100644
--- a/src/coordinate_map_cpu.hpp
+++ b/src/coordinate_map_cpu.hpp
@@ -188,14 +188,16 @@ std::vector<at::Tensor> interpolation_map_weight_kernel(
 
 } // namespace detail
 
-template <typename coordinate_type,
+template <typename coordinate_field_type, typename coordinate_int_type,
           template <typename T> class TemplatedAllocator = std::allocator>
 class CoordinateFieldMapCPU
-    : public CoordinateMap<coordinate_type, TemplatedAllocator> {
+    : public CoordinateMap<coordinate_field_type, TemplatedAllocator> {
   // Coordinate wrapper
 public:
-  using base_type = CoordinateMap<coordinate_type, TemplatedAllocator>;
-  using self_type = CoordinateFieldMapCPU<coordinate_type, TemplatedAllocator>;
+  using base_type = CoordinateMap<coordinate_field_type, TemplatedAllocator>;
+  using self_type =
+      CoordinateFieldMapCPU<coordinate_field_type, coordinate_int_type,
+                            TemplatedAllocator>;
   using size_type = typename base_type::size_type;
   using index_type = typename base_type::index_type;
   using stride_type = typename base_type::stride_type;
@@ -218,8 +220,8 @@ class CoordinateFieldMapCPU
    *
    * @return none
    */
-  void insert(coordinate_type const *coordinate_begin,
-              coordinate_type const *coordinate_end) {
+  void insert(coordinate_field_type const *coordinate_begin,
+              coordinate_field_type const *coordinate_end) {
     size_type N = (coordinate_end - coordinate_begin) / m_coordinate_size;
     base_type::allocate(N);
     // copy data directly to the ptr
@@ -227,11 +229,61 @@ class CoordinateFieldMapCPU
                 base_type::coordinate_data());
   }
 
-  void copy_coordinates(coordinate_type *dst_coordinate) const {
+  using base_type::const_coordinate_data;
+  using base_type::coordinate_data;
+
+  void copy_coordinates(coordinate_field_type *dst_coordinate) const {
     std::copy_n(base_type::const_coordinate_data(), size() * m_coordinate_size,
                 dst_coordinate);
   }
 
+  void quantize_coordinates(coordinate_int_type *p_dst_coordinates,
+                            stride_type const &tensor_stride) const {
+    coordinate_field_type const *const p_tfield = const_coordinate_data();
+    int64_t const stride_prod = std::accumulate(
+        tensor_stride.begin(), tensor_stride.end(), 1, std::multiplies<>());
+    ASSERT(stride_prod > 0, "Invalid stride");
+
+    const size_t N = omp_get_max_threads();
+    const size_t stride = (size() + N - 1) / N;
+    LOG_DEBUG("quantize_coordinates with", N, "chunks and", stride, "stride.");
+
+    if (stride_prod == 1) {
+#pragma omp parallel for
+      for (uint32_t n = 0; n < N; n++) {
+        for (auto i = stride * n;
+             i < std::min((n + 1) * stride, uint64_t(size())); ++i) {
+
+          // batch index
+          coordinate_int_type *p_curr_dst =
+              &p_dst_coordinates[i * m_coordinate_size];
+          p_curr_dst[0] = std::lroundf(p_tfield[i * m_coordinate_size]);
+          for (uint32_t j = 1; j < m_coordinate_size; ++j) {
+            p_curr_dst[j] = std::floor(p_tfield[m_coordinate_size * i + j]);
+          }
+        }
+      }
+    } else {
+#pragma omp parallel for
+      for (uint32_t n = 0; n < N; n++) {
+        for (auto i = stride * n;
+             i < std::min((n + 1) * stride, uint64_t(size())); ++i) {
+
+          // batch index
+          coordinate_int_type *p_curr_dst =
+              &p_dst_coordinates[i * m_coordinate_size];
+          p_curr_dst[0] = std::lroundf(p_tfield[i * m_coordinate_size]);
+          for (uint32_t j = 1; j < m_coordinate_size; ++j) {
+            auto const curr_tensor_stride = tensor_stride[j - 1];
+            p_curr_dst[j] = curr_tensor_stride *
+                            std::floor(p_tfield[m_coordinate_size * i + j] /
+                                       curr_tensor_stride);
+          }
+        }
+      }
+    }
+  }
+
   inline size_type size() const noexcept { return m_size; }
 
   std::string to_string() const {
     Formatter o;
diff --git a/src/coordinate_map_gpu.cu b/src/coordinate_map_gpu.cu
index 1a9f625b..ddbbf899 100644
--- a/src/coordinate_map_gpu.cu
+++ b/src/coordinate_map_gpu.cu
@@ -98,6 +98,96 @@ insert_and_map_kernel(map_type __restrict__ map, //
 
 } // namespace detail
 
+/*
+ * Field Map
+ */
+namespace detail {
+
+template <typename coordinate_field_type, typename coordinate_int_type,
+          typename index_type, bool stride_one>
+__global__ void quantize_coordinates_kernel(
+    coordinate_field_type const *__restrict__ p_tfield, //
+    coordinate_int_type *__restrict__ p_stensor,        //
+    index_type const *__restrict__ p_tensor_stride,     //
+    index_type const num_threads, index_type const coordinate_size) {
+  // coordinate_size * sizeof(index_type) + coordinate_size * sizeof(float_type)
+  // + THREADS * coordinate_size * sizeof(coordinate_type)
+  extern __shared__ index_type sh_tensor_stride[];
+
+  auto const tx = threadIdx.x;
+  auto const bx = blockIdx.x;
+  auto const x = blockDim.x * bx + tx;
+
+  if (stride_one) {
+    if (x < num_threads) {
+      if (x % coordinate_size == 0)
+        p_stensor[x] = lrint(p_tfield[x]);
+      else
+        p_stensor[x] = floor(p_tfield[x]);
+    }
+  } else {
+    for (index_type i = tx; i < coordinate_size - 1; i += blockDim.x) {
+      sh_tensor_stride[i] = p_tensor_stride[i];
+    }
+
+    __syncthreads();
+
+    if (x < num_threads) {
+      // batch index
+      if (x % coordinate_size == 0)
+        p_stensor[x] = lrint(p_tfield[x]);
+      else {
+        index_type curr_tensor_stride =
+            sh_tensor_stride[((x - 1) % coordinate_size)];
+        p_stensor[x] =
+            floor(p_tfield[x] / curr_tensor_stride) * curr_tensor_stride;
+      }
+    }
+  }
+}
+} // namespace detail
+
+template <typename coordinate_field_type, typename coordinate_int_type,
+          template <typename T> class TemplatedAllocator>
+void CoordinateFieldMapGPU<coordinate_field_type, coordinate_int_type,
+                           TemplatedAllocator>::
+    quantize_coordinates(coordinate_int_type *d_dst_coordinates,
+                         stride_type const &tensor_stride) const {
+  int64_t const stride_prod = std::accumulate(
+      tensor_stride.begin(), tensor_stride.end(), 1, std::multiplies<>());
+
+  // Copy tensor_stride to device
+  index_type *d_tensor_stride = reinterpret_cast<index_type *>(
+      m_byte_allocator.allocate(m_coordinate_size * sizeof(index_type)));
+  CUDA_CHECK(cudaMemcpy(
+      d_tensor_stride,      // dst
+      tensor_stride.data(), // first element of the dereferenced iter.
+      sizeof(index_type) * m_coordinate_size, // bytes
+      cudaMemcpyHostToDevice));
+
+  size_type const num_threads = size() * m_coordinate_size;
+  auto const num_blocks = GET_BLOCKS(num_threads, CUDA_NUM_THREADS);
+
+  if (stride_prod == 1) {
+    detail::quantize_coordinates_kernel<coordinate_field_type, coordinate_int_type, index_type, true>
+        <<<num_blocks, CUDA_NUM_THREADS, m_coordinate_size * sizeof(index_type)>>>(
+            const_coordinate_data(), d_dst_coordinates, d_tensor_stride,
+            num_threads, m_coordinate_size);
+  } else {
+    detail::quantize_coordinates_kernel<coordinate_field_type, coordinate_int_type, index_type, false>
+        <<<num_blocks, CUDA_NUM_THREADS, m_coordinate_size * sizeof(index_type)>>>(
+            const_coordinate_data(), d_dst_coordinates, d_tensor_stride,
+            num_threads, m_coordinate_size);
+  }
+  // release the temporary device copy of the tensor stride
+  CUDA_CHECK(cudaStreamSynchronize(0));
+  m_byte_allocator.deallocate(reinterpret_cast<char *>(d_tensor_stride),
+                              m_coordinate_size * sizeof(index_type));
+}
+
 /*
  * @brief Given a key iterator begin-end pair and a value iterator begin-end
  * pair, insert all elements.
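
Note: the OpenMP path in `coordinate_map_cpu.hpp` and this CUDA path implement the same per-point quantization rule: the batch index (column 0) is rounded to the nearest integer, and each spatial coordinate is floored down to a multiple of its tensor stride. A minimal NumPy sketch of that rule, for reference only (the `quantize` helper below is illustrative and not part of this patch):

```python
import numpy as np

def quantize(field_coords, tensor_stride):
    """field_coords: (N, 1 + D) float array whose first column is the batch index."""
    stride = np.asarray(tensor_stride)       # D tensor strides, one per spatial dim
    out = np.empty(field_coords.shape, dtype=np.int32)
    out[:, 0] = np.rint(field_coords[:, 0])  # batch index: round to nearest
    # spatial dims: floor to the nearest lower multiple of the stride
    out[:, 1:] = stride * np.floor(field_coords[:, 1:] / stride)
    return out

coords = np.array([[0, 0.1], [0, 2.3], [0, 1.2], [0, 2.4]], dtype=np.float32)
print(quantize(coords, [2]))  # [[0 0] [0 2] [0 0] [0 2]]
```

The `stride_prod == 1` branch in both implementations is just the fast path of this rule, since flooring by a stride of one needs no division or multiplication.
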
@@ -265,8 +352,10 @@ stride_copy(coordinate_type const *__restrict__ src_coordinates, //
   auto const bx = blockIdx.x;
   auto const x = blockDim.x * bx + tx;
 
-  if (tx < coordinate_size - 1)
-    sh_stride[tx] = stride[tx];
+  for (index_type i = tx; i < coordinate_size - 1; i += blockDim.x)
+    sh_stride[i] = stride[i];
+
+  __syncthreads();
 
   if (x < num_threads) {
     const index_type src_start = src_valid_row_index[x] * coordinate_size;
@@ -2038,13 +2127,11 @@ void CoordinateMapGPU<coordinate_type, TemplatedAllocator>::copy_coordinates(
 }
 
 // Template instantiation
-template class CoordinateFieldMapGPU<default_types::ccoordinate_type, detail::default_allocator>;
-template class CoordinateFieldMapGPU<default_types::ccoordinate_type, detail::c10_allocator>;
+template class CoordinateFieldMapGPU<default_types::ccoordinate_type, default_types::dcoordinate_type, detail::default_allocator>;
+template class CoordinateFieldMapGPU<default_types::ccoordinate_type, default_types::dcoordinate_type, detail::c10_allocator>;
 template class CoordinateMapGPU<default_types::dcoordinate_type, detail::default_allocator>;
 
diff --git a/src/coordinate_map_gpu.cuh b/src/coordinate_map_gpu.cuh
--- a/src/coordinate_map_gpu.cuh
+++ b/src/coordinate_map_gpu.cuh
-template <typename coordinate_type,
-          template <typename T> class TemplatedAllocator =
-              detail::c10_allocator>
+template <typename coordinate_field_type, typename coordinate_int_type,
+          template <typename T> class TemplatedAllocator =
+              detail::c10_allocator>
 class CoordinateFieldMapGPU
-    : public CoordinateMap<coordinate_type, TemplatedAllocator> {
+    : public CoordinateMap<coordinate_field_type, TemplatedAllocator> {
   // Coordinate wrapper
 public:
-  using base_type = CoordinateMap<coordinate_type, TemplatedAllocator>;
-  using self_type = CoordinateFieldMapGPU<coordinate_type, TemplatedAllocator>;
+  using base_type = CoordinateMap<coordinate_field_type, TemplatedAllocator>;
+  using self_type =
+      CoordinateFieldMapGPU<coordinate_field_type, coordinate_int_type,
+                            TemplatedAllocator>;
   using size_type = typename base_type::size_type;
   using index_type = typename base_type::index_type;
   using stride_type = typename base_type::stride_type;
@@ -69,22 +71,29 @@ public:
    *
    * @return none
    */
-  void insert(coordinate_type const *coordinate_begin,
-              coordinate_type const *coordinate_end) {
+  void insert(coordinate_field_type const *coordinate_begin,
+              coordinate_field_type const *coordinate_end) {
     size_type N = (coordinate_end - coordinate_begin) / m_coordinate_size;
     base_type::allocate(N);
     // copy data directly to the ptr
     CUDA_CHECK(cudaMemcpy(base_type::coordinate_data(), coordinate_begin,
-                          N * m_coordinate_size * sizeof(coordinate_type),
+                          N * m_coordinate_size * sizeof(coordinate_field_type),
                           cudaMemcpyDeviceToDevice));
   }
 
-  void copy_coordinates(coordinate_type *dst_coordinate) const {
-    CUDA_CHECK(cudaMemcpy(dst_coordinate, base_type::const_coordinate_data(),
-                          size() * m_coordinate_size * sizeof(coordinate_type),
-                          cudaMemcpyDeviceToDevice));
+  void copy_coordinates(coordinate_field_type *dst_coordinate) const {
+    CUDA_CHECK(
+        cudaMemcpy(dst_coordinate, base_type::const_coordinate_data(),
+                   size() * m_coordinate_size * sizeof(coordinate_field_type),
+                   cudaMemcpyDeviceToDevice));
   }
 
+  void quantize_coordinates(coordinate_int_type *p_dst_coordinates,
+                            stride_type const &tensor_stride) const;
+
+  using base_type::coordinate_data;
+  using base_type::const_coordinate_data;
+
   inline size_type size() const noexcept { return m_size; }
 
   std::string to_string() const {
     Formatter o;
@@ -93,19 +102,20 @@ public:
   }
 
 private:
+  using base_type::m_byte_allocator;
   using base_type::m_coordinate_size;
   size_type m_size;
 };
 
-// clang-format off
 /*
  * Inherit from the CoordinateMap for a concurrent coordinate unordered map.
  */
-template <typename coordinate_type, template <typename T> class TemplatedAllocator = detail::c10_allocator>
-class CoordinateMapGPU : public CoordinateMap<coordinate_type, TemplatedAllocator> {
+template <typename coordinate_type,
+          template <typename T> class TemplatedAllocator =
+              detail::c10_allocator>
+class CoordinateMapGPU
+    : public CoordinateMap<coordinate_type, TemplatedAllocator> {
 public:
-  // clang-format off
   using base_type = CoordinateMap<coordinate_type, TemplatedAllocator>;
   using self_type = CoordinateMapGPU<coordinate_type, TemplatedAllocator>;
diff --git a/src/coordinate_map_manager.cpp b/src/coordinate_map_manager.cpp
index 9206a1b0..88786eaf 100644
--- a/src/coordinate_map_manager.cpp
+++ b/src/coordinate_map_manager.cpp
@@ -45,93 +45,6 @@ default_types::stride_type zeros(size_t const len) { return _fill_vec<0>(len); }
 default_types::stride_type ones(size_t const len) { return _fill_vec<1>(len); }
 } // namespace detail
 
-/*
-
-template <typename MapType>
-vector<at::Tensor>
-CoordsManager<MapType>::getCoordsMap(py::object py_in_coords_key,
-                                     py::object py_out_coords_key) const {
-  CoordsKey *p_in_coords_key = py_in_coords_key.cast<CoordsKey *>();
-  CoordsKey *p_out_coords_key = py_out_coords_key.cast<CoordsKey *>();
-  const uint64_t in_coords_key = p_in_coords_key->getKey();
-  const uint64_t out_coords_key = p_out_coords_key->getKey();
-
-  const auto in_map_iter = coords_maps.find(in_coords_key);
-  const auto out_map_iter = coords_maps.find(out_coords_key);
-
-  ASSERT(in_map_iter != coords_maps.end(), "Input coords not found at",
-         to_string(in_coords_key));
-  ASSERT(out_map_iter != coords_maps.end(), "Output coords not found at",
-         to_string(out_coords_key));
-
-  const auto &out_tensor_strides = p_out_coords_key->getTensorStride();
-  const auto in_out =
-      in_map_iter->second.stride_map(out_map_iter->second, out_tensor_strides);
-
-  const auto &ins = in_out.first;
-  const auto &outs = in_out.second;
-  // All size
-  const auto N = std::accumulate(ins.begin(), ins.end(), 0,
-                                 [](size_t curr_sum, const vector<int> &map) {
-                                   return curr_sum + map.size();
-                                 });
-
-  at::Tensor in_out_1 =
-      torch::empty({N}, torch::TensorOptions().dtype(torch::kInt64));
-  at::Tensor in_out_2 =
-      torch::empty({N}, torch::TensorOptions().dtype(torch::kInt64));
-
-  auto a_in_out_1 = in_out_1.accessor<int64_t, 1>();
-  auto a_in_out_2 = in_out_2.accessor<int64_t, 1>();
-
-  size_t curr_it = 0;
-  for (const auto &in : ins)
-    for (const auto i : in)
-      a_in_out_1[curr_it++] = i;
-
-  curr_it = 0;
-  for (const auto &out : outs)
-    for (const auto o : out)
-      a_in_out_2[curr_it++] = o;
-
-  return {in_out_1, in_out_2};
-}
-
-// Generate and return the ins -> out map.
-template <typename MapType>
-pair<vector<at::Tensor>, vector<at::Tensor>>
-CoordsManager<MapType>::getUnionMap(vector<py::object> py_in_coords_keys,
-                                    py::object py_out_coords_key) {
-
-  // all exception handling will be done inside the following
-  const InOutMapsRefPair<int> in_outs =
-      getUnionInOutMaps(py_in_coords_keys, py_out_coords_key);
-  const auto &ins = in_outs.first;
-  const auto &outs = in_outs.second;
-
-  // Size of the in out maps
-  const auto N = ins.size();
-
-  // Return torch tensor
-  vector<at::Tensor> th_ins;
-  vector<at::Tensor> th_outs;
-  for (size_t i = 0; i < N; ++i) {
-    at::Tensor th_in = torch::empty(
-        {(long)ins[i].size()}, torch::TensorOptions().dtype(torch::kInt64));
-    at::Tensor th_out = torch::empty(
-        {(long)outs[i].size()}, torch::TensorOptions().dtype(torch::kInt64));
-
-    copy_types(ins[i], th_in);
-    copy_types(outs[i], th_out);
-
-    th_ins.push_back(move(th_in));
-    th_outs.push_back(move(th_out));
-  }
-
-  return make_pair(th_ins, th_outs);
-}
-
-*/
 
 /*******************************
  * Initialization
@@ -189,7 +102,8 @@ struct insert_and_map_functor
 struct insert_field_functor<
     coordinate_type, coordinate_field_type, std::allocator, CoordinateMapCPU,
-    CoordinateFieldMapCPU<coordinate_field_type, std::allocator>> {
+    CoordinateFieldMapCPU<coordinate_field_type, coordinate_type,
+                          std::allocator>> {
 
   void operator()(coordinate_map_key_type &map_key,
                   at::Tensor const &th_coordinate,
@@ -200,8 +114,9 @@ struct insert_field_functor<
     uint32_t const coordinate_size = th_coordinate.size(1);
     coordinate_field_type *p_coordinate =
         th_coordinate.data_ptr<coordinate_field_type>();
-    auto map = CoordinateFieldMapCPU<coordinate_field_type, std::allocator>(
-        N, coordinate_size, map_key.first);
+    auto map = CoordinateFieldMapCPU<coordinate_field_type, coordinate_type,
+                                     std::allocator>(N, coordinate_size,
+                                                     map_key.first);
     map.insert(p_coordinate, p_coordinate + N * coordinate_size);
 
     LOG_DEBUG("insert map with tensor_stride", map_key.first);
@@ -270,15 +185,107 @@ py::object CoordinateMapManager
 
+template <typename coordinate_type, typename coordinate_field_type,
+          template <typename C> class TemplatedAllocator,
+          template <typename T, template <typename Q> class A>
+          class CoordinateMapType>
+std::pair<py::object, std::pair<at::Tensor, at::Tensor>>
+CoordinateMapManager<coordinate_type, coordinate_field_type,
+                     TemplatedAllocator, CoordinateMapType>::
+    field_to_sparse_insert_and_map(
+        CoordinateMapKey const *p_in_field_map_key,
+        default_types::stride_type const sparse_tensor_stride,
+        std::string const sparse_tensor_string_id) {
+  auto const coordinate_size = p_in_field_map_key->get_coordinate_size();
+  // Basic assertions
+  ASSERT(coordinate_size - 1 == sparse_tensor_stride.size(),
+         "The coordinate dimension (coordinate_size - 1):", coordinate_size - 1,
+         " must match the size of tensor stride:",
+         ArrToString(sparse_tensor_stride));
+
+  // Find coordinate field
+  auto const it = m_field_coordinates.find(p_in_field_map_key->get_key());
+  ASSERT(it != m_field_coordinates.end(), ERROR_MAP_NOT_FOUND);
+  auto const &field_map = it->second;
+  auto const nrows = field_map.size();
+  auto const ncols = field_map.coordinate_size();
+
+  auto options = torch::TensorOptions().dtype(torch::kInt).requires_grad(false);
+
+  if (!detail::is_cpu_coordinate_map<CoordinateMapType>::value) {
+#ifndef CPU_ONLY
+    auto device_id = at::cuda::current_device();
+    options = options.device(torch::kCUDA, device_id);
+#else
+    ASSERT(false, ERROR_CPU_ONLY);
+#endif
+  }
+
+  // generate the map_key
+  coordinate_map_key_type map_key =
+      std::make_pair(sparse_tensor_stride, sparse_tensor_string_id);
+  if (m_coordinate_maps.find(map_key) != m_coordinate_maps.end()) {
+    LOG_DEBUG("CoordinateMapKey collision detected:", map_key,
+              "generating new string id.");
+    map_key =
+        get_random_string_id(sparse_tensor_stride, sparse_tensor_string_id);
+  }
+
+  LOG_DEBUG("initializing a field map with tensor stride:", map_key.first,
+            "string id:", map_key.second);
+
+  // Quantize the field with tensor stride.
+  // The coordinate must be a tensor. Wrap a pointer with a tensor.
+  at::Tensor int_coordinates =
+      at::empty({field_map.size(), coordinate_size}, options);
+  field_map.quantize_coordinates(int_coordinates.data_ptr<coordinate_type>(),
+                                 sparse_tensor_stride);
+
+  auto const map_inverse_map =
+      detail::insert_and_map_functor<coordinate_type, TemplatedAllocator,
+                                     CoordinateMapType>()(
+          map_key, int_coordinates, *this);
+
+  auto const field_to_sparse_map_key =
+      std::pair<coordinate_map_key_type, coordinate_map_key_type>{
+          p_in_field_map_key->get_key(), map_key};
+
+  auto result = m_field_to_sparse_maps.insert(
+      std::pair<
+          const std::pair<coordinate_map_key_type, coordinate_map_key_type>,
+          const std::pair<at::Tensor, at::Tensor>>{field_to_sparse_map_key,
+                                                   map_inverse_map});
+  LOG_DEBUG("field to sparse tensor map insertion", result.second);
+
+  py::object py_key = py::cast(new CoordinateMapKey(coordinate_size, map_key));
+
+  return std::make_pair(py_key, map_inverse_map);
+}
+
+template <typename coordinate_type, typename coordinate_field_type,
+          template <typename C> class TemplatedAllocator,
+          template <typename T, template <typename Q> class A>
+          class CoordinateMapType>
+std::pair<at::Tensor, at::Tensor>
+CoordinateMapManager<coordinate_type, coordinate_field_type,
+                     TemplatedAllocator, CoordinateMapType>::
+    get_field_to_sparse_map(CoordinateMapKey const *p_field_key,
+                            CoordinateMapKey const *p_sparse_key) const {
+  auto key = std::pair<coordinate_map_key_type, coordinate_map_key_type>{
+      p_field_key->get_key(), p_sparse_key->get_key()};
+  auto it = m_field_to_sparse_maps.find(key);
+  ASSERT(it != m_field_to_sparse_maps.end(),
+         "Field To Sparse Map doesn't exist");
+  return it->second;
+}
+
 /*
  * coords: coordinates in IntTensor
- * mapping: output mapping in IntTensor
  * tensor_strides: current tensor strides this coords will be initializeds
- * force_creation: even when there's a duplicate coords with the same tensor
- * strides.
- * force_remap: if there's duplicate coords, remap
- * allow_duplicate_coords: create map when there are duplicates in the
- * coordinates
  */
 template <typename coordinate_type, typename coordinate_field_type,
           template <typename C> class TemplatedAllocator,
diff --git a/src/coordinate_map_manager.cu b/src/coordinate_map_manager.cu
index 608c25bc..993293d4 100644
--- a/src/coordinate_map_manager.cu
+++ b/src/coordinate_map_manager.cu
@@ -116,7 +116,8 @@ template <typename coordinate_type, typename coordinate_field_type,
 struct insert_field_functor<
     coordinate_type, coordinate_field_type, TemplatedAllocator,
     CoordinateMapGPU,
-    CoordinateFieldMapGPU<coordinate_field_type, TemplatedAllocator>> {
+    CoordinateFieldMapGPU<coordinate_field_type, coordinate_type,
+                          TemplatedAllocator>> {
 
   void operator()(
       coordinate_map_key_type &map_key, at::Tensor const &th_coordinate,
@@ -127,8 +128,9 @@ struct insert_field_functor<
     uint32_t const coordinate_size = th_coordinate.size(1);
     coordinate_field_type *p_coordinate =
         th_coordinate.data_ptr<coordinate_field_type>();
-    auto map = CoordinateFieldMapGPU<coordinate_field_type, TemplatedAllocator>(
-        N, coordinate_size, map_key.first);
+    auto map = CoordinateFieldMapGPU<coordinate_field_type, coordinate_type,
+                                     TemplatedAllocator>(N, coordinate_size,
+                                                         map_key.first);
     map.insert(p_coordinate, p_coordinate + N * coordinate_size);
 
     LOG_DEBUG("insert map with tensor_stride", map_key.first);
diff --git a/src/coordinate_map_manager.hpp b/src/coordinate_map_manager.hpp
index 0036ac2c..ee67cb5d 100644
--- a/src/coordinate_map_manager.hpp
+++ b/src/coordinate_map_manager.hpp
@@ -93,11 +93,14 @@ class CoordinateMapManager {
 #ifndef CPU_ONLY
   using field_map_type = typename std::conditional<
       detail::is_cpu_coordinate_map<CoordinateMapType>::value,
-      CoordinateFieldMapCPU<coordinate_field_type, TemplatedAllocator>,
-      CoordinateFieldMapGPU<coordinate_field_type, TemplatedAllocator>>::type;
+      CoordinateFieldMapCPU<coordinate_field_type, coordinate_type,
+                            TemplatedAllocator>,
+      CoordinateFieldMapGPU<coordinate_field_type, coordinate_type,
+                            TemplatedAllocator>>::type;
 #else
   using field_map_type =
-      CoordinateFieldMapCPU<coordinate_field_type, TemplatedAllocator>;
+      CoordinateFieldMapCPU<coordinate_field_type, coordinate_type,
+                            TemplatedAllocator>;
 #endif
 
   using self_type = CoordinateMapManager<coordinate_type, coordinate_field_type,
                                          TemplatedAllocator, CoordinateMapType>;
@@ -160,6 +163,16 @@ class CoordinateMapManager {
                  stride_type const tensor_stride,
                  std::string const string_id = "");
 
+  /*
+   * New coordinate map initialization function.
+   *
+   * returns key and map, inverse map
+   */
+  std::pair<py::object, std::pair<at::Tensor, at::Tensor>>
+  field_to_sparse_insert_and_map(CoordinateMapKey const *p_in_field_map_key,
+                                 stride_type const sparse_tensor_stride,
+                                 std::string const sparse_string_id = "");
+
   /*
    * New coordinate map initialzation function.
    *
@@ -261,6 +274,14 @@ class CoordinateMapManager {
     return m_field_coordinates.find(key) != m_field_coordinates.end();
   }
 
+  inline bool exists_field_to_sparse(
+      coordinate_map_key_type const &field_key,
+      coordinate_map_key_type const &sparse_key) const noexcept {
+    auto key = std::pair<coordinate_map_key_type, coordinate_map_key_type>{
+        field_key, sparse_key};
+    return m_field_to_sparse_maps.find(key) != m_field_to_sparse_maps.end();
+  }
+
   // when the key is the python coordinate map key
   inline bool exists(CoordinateMapKey const *p_key) const {
     // key set exception
@@ -273,6 +294,14 @@ class CoordinateMapManager {
     return exists_field(p_key->get_key());
   }
 
+  inline bool
+  exists_field_to_sparse(CoordinateMapKey const *p_field_key,
+                         CoordinateMapKey const *p_sparse_key) const {
+    // key set exception
+    return exists_field_to_sparse(p_field_key->get_key(),
+                                  p_sparse_key->get_key());
+  }
+
   inline size_type size(coordinate_map_key_type const &key) const {
     auto it = m_coordinate_maps.find(key);
     ASSERT(it != m_coordinate_maps.end(), ERROR_MAP_NOT_FOUND);
@@ -293,6 +322,10 @@ class CoordinateMapManager {
 
   at::Tensor get_coordinate_field(CoordinateMapKey const *p_key) const;
 
+  std::pair<at::Tensor, at::Tensor>
+  get_field_to_sparse_map(CoordinateMapKey const *p_field_key,
+                          CoordinateMapKey const *p_sparse_key) const;
+
   std::vector<py::object>
   get_coordinate_map_keys(stride_type const tensor_stride) const {
     std::vector<py::object> keys;
@@ -462,6 +495,12 @@ class CoordinateMapManager {
                      kernel_map_key_hasher<coordinate_map_key_hasher>>
       m_kernel_maps;
 
+  std::unordered_map<
+      const std::pair<coordinate_map_key_type, coordinate_map_key_type>,
+      const std::pair<at::Tensor, at::Tensor>,
+      field_to_sparse_map_key_hasher<coordinate_map_key_hasher>>
+      m_field_to_sparse_maps;
+
 #ifndef CPU_ONLY
   TemplatedAllocator<char> m_allocator;
 #endif
diff --git a/src/types.hpp b/src/types.hpp
index 3d6fbee1..e216d828 100644
--- a/src/types.hpp
+++ b/src/types.hpp
@@ -83,7 +83,7 @@ struct coordinate_map_key_hasher {
   result_type operator()(coordinate_map_key_type const &key) const {
     auto hash = robin_hood::hash_bytes(
         key.first.data(), sizeof(default_types::size_type) * key.first.size());
-    hash += std::hash<std::string>{}(key.second);
+    hash ^= std::hash<std::string>{}(key.second);
     return hash;
   }
 };
@@ -169,7 +169,6 @@ enum Type {
 };
 }
 
-
 /* Key for KernelMap
  *
  * A tuple of (CoordinateMapKey (input),
@@ -229,6 +228,18 @@ struct kernel_map_key_hasher {
   }
 };
 
+template <typename hasher>
+struct field_to_sparse_map_key_hasher {
+  using result_type = size_t;
+
+  result_type operator()(std::pair<coordinate_map_key_type,
+                                   coordinate_map_key_type> const &key) const {
+    result_type hash = hasher{}(key.first);
+    hash ^= hasher{}(key.second);
+    return hash;
+  }
+};
+
 } // end namespace minkowski
 
 #endif // TYPES_HPP
diff --git a/tests/python/tensor_field.py b/tests/python/tensor_field.py
index f5a9d78f..c48e2db5 100644
--- a/tests/python/tensor_field.py
+++ b/tests/python/tensor_field.py
@@ -41,7 +41,7 @@ def test(self):
             [[0, 1], [0, 1], [0, 2], [0, 2], [1, 0], [1, 0], [1, 1]]
         )
         feats = torch.FloatTensor([[0, 1, 2, 3, 5, 6, 7]]).T
-        sfield = TensorField(feats, coords, device=feats.device)
+        sfield = TensorField(feats, coords)
 
         # Convert to a sparse tensor
         stensor = sfield.sparse(
@@ -52,6 +52,21 @@ def test(self):
             {0.5, 2.5, 5.5, 7} == {a for a in stensor.F.squeeze().detach().numpy()}
         )
 
+        # device cuda
+        if not torch.cuda.is_available():
+            return
+
+        sfield = TensorField(feats, coords, device="cuda")
+
+        # Convert to a sparse tensor
+        stensor = sfield.sparse(
+            quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
+        )
+        print(stensor)
+        self.assertTrue(
+            {0.5, 2.5, 5.5, 7} == {a for a in stensor.F.squeeze().detach().cpu().numpy()}
+        )
+
     def test_pcd(self):
         coords, colors, pcd = load_file("1.ply")
         voxel_size = 0.02
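
For orientation, here is a hedged sketch of how the new `CoordinateManager`-level API above chains together. It follows the signatures added in this diff (notably the return annotations of `insert_field` and `field_to_sparse_insert_and_map`); the D=1 coordinates are the ones used in the docstrings, and nothing else is assumed:

```python
import torch
import MinkowskiEngine as ME

manager = ME.CoordinateManager(D=1)
coordinates = torch.FloatTensor([[0, 0.1], [0, 2.3], [0, 1.2], [0, 2.4]])

# Register the continuous coordinates as a field.
field_key, _ = manager.insert_field(coordinates, [1])

# Quantize the field into a sparse coordinate map; the (unique_map,
# inverse_map) pair is cached under the (field key, sparse key) pair.
sparse_key, (unique_map, inverse_map) = manager.field_to_sparse_insert_and_map(
    field_key, [1]
)

assert manager.exists_field_to_sparse(field_key, sparse_key)
# Retrieve the cached pair later without re-quantizing.
unique_map2, inverse_map2 = manager.get_field_to_sparse_map(field_key, sparse_key)
```
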
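And an end-to-end sketch of the `TensorField` workflow this patch enables: `sparse()` now caches the field-to-sparse inverse map in the coordinate manager under the resulting coordinate map key, so `slice()` can recover per-point features after a network runs. The single `MinkowskiLinear` layer here is a stand-in for any network that preserves the coordinate map; the rest follows the code in this diff:

```python
import torch
import MinkowskiEngine as ME

coords = torch.FloatTensor([[0, 0.1], [0, 2.3], [0, 1.2], [0, 2.4]])
feats = torch.FloatTensor([[1], [2], [3], [4]])
tfield = ME.TensorField(
    feats,
    coords,
    quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
)

# Voxelize; the (unique_map, inverse_map) pair is cached in the coordinate
# manager under (field key, sparse key) instead of on the TensorField itself.
stensor = tfield.sparse(tensor_stride=1)

# Stand-in network; it keeps the same coordinate map key.
network = ME.MinkowskiLinear(1, 8)
out = network(stensor)

# slice() now calls X.inverse_mapping(self.coordinate_map_key) to look up the
# cached map, so every original point receives the feature of its voxel.
out_field = out.slice(tfield)
assert len(out_field.F) == len(tfield.F)
```

Because the cache is keyed by the sparse tensor's coordinate map key, one field can be voxelized at several tensor strides and sliced back from any of them, which is what the dictionary-valued `_inverse_mapping` in `MinkowskiTensorField.py` is for.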