From d30e6c7ed70c0862da93480cf7a0e475f92ccbd9 Mon Sep 17 00:00:00 2001 From: Chris Choy Date: Thu, 20 Aug 2020 15:07:29 -0700 Subject: [PATCH] field map insertion --- MinkowskiEngine/MinkowskiCoordinateManager.py | 32 ++- MinkowskiEngine/MinkowskiNonlinearity.py | 5 +- MinkowskiEngine/MinkowskiNormalization.py | 5 +- MinkowskiEngine/MinkowskiOps.py | 5 +- MinkowskiEngine/MinkowskiSparseTensor.py | 55 +++-- MinkowskiEngine/MinkowskiTensorField.py | 24 +- pybind/extern.hpp | 1 + src/coordinate_map_cpu.hpp | 22 +- src/coordinate_map_gpu.cu | 16 +- src/coordinate_map_gpu.cuh | 14 +- src/coordinate_map_manager.cpp | 233 ++++++++++++++---- src/coordinate_map_manager.cu | 49 +++- src/coordinate_map_manager.hpp | 66 +++-- tests/python/tensor_field.py | 4 +- 14 files changed, 390 insertions(+), 141 deletions(-) diff --git a/MinkowskiEngine/MinkowskiCoordinateManager.py b/MinkowskiEngine/MinkowskiCoordinateManager.py index 7eae631c..e810a018 100644 --- a/MinkowskiEngine/MinkowskiCoordinateManager.py +++ b/MinkowskiEngine/MinkowskiCoordinateManager.py @@ -147,14 +147,15 @@ def __init__( def insert_and_map( self, - coordinates: torch.IntTensor, - tensor_stride: Union[int, Sequence, np.ndarray, torch.Tensor], + coordinates: torch.Tensor, + tensor_stride: Union[int, Sequence, np.ndarray], string_id: str = "", ) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]: r"""create a new coordinate map and returns - :attr:`coordinates`: `torch.IntTensor` (`CUDA` if coordinate_map_type - == `CoordinateMapType.GPU`) that defines the coordinates. + :attr:`coordinates`: `torch.Tensor` (Int tensor. `CUDA` if + coordinate_map_type == `CoordinateMapType.GPU`) that defines the + coordinates. Example:: @@ -168,6 +169,29 @@ def insert_and_map( """ return self._manager.insert_and_map(coordinates, tensor_stride, string_id) + def insert_field( + self, + coordinates: torch.Tensor, + tensor_stride: Union[int, Sequence, np.ndarray], + string_id: str = "", + ) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]: + r"""create a new coordinate map and returns + + :attr:`coordinates`: `torch.FloatTensor` (`CUDA` if coordinate_map_type + == `CoordinateMapType.GPU`) that defines the coordinates. + + Example:: + + >>> manager = CoordinateManager(D=1) + >>> coordinates = torch.FloatTensor([[0, 0.1], [0, 2.3], [0, 1.2], [0, 2.4]]) + >>> key, (unique_map, inverse_map) = manager.insert(coordinates, [1]) + >>> print(key) # key is tensor_stride, string_id [1]:"" + >>> torch.all(coordinates[unique_map] == manager.get_coordinates(key)) # True + >>> torch.all(coordinates == coordinates[unique_map][inverse_map]) # True + + """ + return self._manager.insert_field(coordinates, tensor_stride, string_id) + def stride( self, coordinate_map_key: CoordinateMapKey, diff --git a/MinkowskiEngine/MinkowskiNonlinearity.py b/MinkowskiEngine/MinkowskiNonlinearity.py index 32ad4e8a..fe75e855 100644 --- a/MinkowskiEngine/MinkowskiNonlinearity.py +++ b/MinkowskiEngine/MinkowskiNonlinearity.py @@ -38,15 +38,16 @@ def __init__(self, *args, **kwargs): def forward(self, input): output = self.module(input.F) if isinstance(input, TensorField): - return input.__class__( + return TensorField( output, coordinate_map_key=input.coordinate_map_key, + coordinate_field_map_key=input.coordinate_field_map_key, coordinate_manager=input.coordinate_manager, inverse_mapping=input.inverse_mapping, quantization_mode=input.quantization_mode, ) else: - return input.__class__( + return SparseTensor( output, coordinate_map_key=input.coordinate_map_key, coordinate_manager=input.coordinate_manager, diff --git a/MinkowskiEngine/MinkowskiNormalization.py b/MinkowskiEngine/MinkowskiNormalization.py index f356e79a..e2e27740 100644 --- a/MinkowskiEngine/MinkowskiNormalization.py +++ b/MinkowskiEngine/MinkowskiNormalization.py @@ -67,15 +67,16 @@ def __init__( def forward(self, input): output = self.bn(input.F) if isinstance(input, TensorField): - return input.__class__( + return TensorField( output, coordinate_map_key=input.coordinate_map_key, + coordinate_field_map_key=input.coordinate_field_map_key, coordinate_manager=input.coordinate_manager, inverse_mapping=input.inverse_mapping, quantization_mode=input.quantization_mode, ) else: - return input.__class__( + return SparseTensor( output, coordinate_map_key=input.coordinate_map_key, coordinate_manager=input.coordinate_manager, diff --git a/MinkowskiEngine/MinkowskiOps.py b/MinkowskiEngine/MinkowskiOps.py index 5d9f1f66..16230c06 100644 --- a/MinkowskiEngine/MinkowskiOps.py +++ b/MinkowskiEngine/MinkowskiOps.py @@ -43,15 +43,16 @@ def __init__(self, in_features, out_features, bias=True): def forward(self, input: Union[SparseTensor, TensorField]): output = self.linear(input.F) if isinstance(input, TensorField): - return input.__class__( + return TensorField( output, coordinate_map_key=input.coordinate_map_key, + coordinate_field_map_key=input.coordinate_field_map_key, coordinate_manager=input.coordinate_manager, inverse_mapping=input.inverse_mapping, quantization_mode=input.quantization_mode, ) else: - return input.__class__( + return SparseTensor( output, coordinate_map_key=input.coordinate_map_key, coordinate_manager=input.coordinate_manager, diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py index 61f1c9aa..9954d1f3 100644 --- a/MinkowskiEngine/MinkowskiSparseTensor.py +++ b/MinkowskiEngine/MinkowskiSparseTensor.py @@ -384,13 +384,24 @@ def slice(self, X, slicing_mode=0): ), "Slice can only be applied on the same coordinates (coordinate_map_key)" from MinkowskiTensorField import TensorField - return TensorField( - self.F[X.inverse_mapping], - coordinate_map_key=self.coordinate_map_key, - coordinate_manager=self.coordinate_manager, - inverse_mapping=X.inverse_mapping, - quantization_mode=X.quantization_mode, - ) + if isinstance(X, TensorField): + return TensorField( + self.F[X.inverse_mapping], + coordinate_map_key=X.coordinate_map_key, + coordinate_field_map_key=X.coordinate_field_map_key, + coordinate_manager=X.coordinate_manager, + inverse_mapping=X.inverse_mapping, + quantization_mode=X.quantization_mode, + ) + else: + return TensorField( + self.F[X.inverse_mapping], + coordinates=self.C[X.inverse_mapping], + coordinate_map_key=X.coordinate_map_key, + coordinate_manager=X.coordinate_manager, + inverse_mapping=X.inverse_mapping, + quantization_mode=X.quantization_mode, + ) def cat_slice(self, X, slicing_mode=0): r""" @@ -428,13 +439,25 @@ def cat_slice(self, X, slicing_mode=0): ), "Slice can only be applied on the same coordinates (coordinate_map_key)" from MinkowskiTensorField import TensorField - return TensorField( - torch.cat((self.F[X.inverse_mapping], X.F), dim=1), - coordinate_map_key=self.coordinate_map_key, - coordinate_manager=self.coordinate_manager, - inverse_mapping=X.inverse_mapping, - quantization_mode=X.quantization_mode, - ) + features = torch.cat((self.F[X.inverse_mapping], X.F), dim=1) + if isinstance(X, TensorField): + return TensorField( + features, + coordinate_map_key=X.coordinate_map_key, + coordinate_field_map_key=X.coordinate_field_map_key, + coordinate_manager=X.coordinate_manager, + inverse_mapping=X.inverse_mapping, + quantization_mode=X.quantization_mode, + ) + else: + return TensorField( + features, + coordinates=self.C[X.inverse_mapping], + coordinate_map_key=X.coordinate_map_key, + coordinate_manager=X.coordinate_manager, + inverse_mapping=X.inverse_mapping, + quantization_mode=X.quantization_mode, + ) def features_at_coords(self, query_coords: torch.Tensor): r"""Extract features at the specified coordinate matrix. @@ -501,7 +524,9 @@ def _get_coordinate_map_key( ( coordinate_map_key, (unique_index, inverse_mapping), - ) = input._manager.insert_and_map(coordinates, *coordinate_map_key.get_key()) + ) = input._manager.insert_and_map( + coordinates, *coordinate_map_key.get_key() + ) elif isinstance(coordinates, SparseTensor): coordinate_map_key = coordinates.coordinate_map_key else: # CoordinateMapKey type due to the previous assertion diff --git a/MinkowskiEngine/MinkowskiTensorField.py b/MinkowskiEngine/MinkowskiTensorField.py index d6e9a5e1..1a365a8d 100644 --- a/MinkowskiEngine/MinkowskiTensorField.py +++ b/MinkowskiEngine/MinkowskiTensorField.py @@ -48,6 +48,7 @@ def __init__( # optional coordinate related arguments tensor_stride: StrideType = 1, coordinate_map_key: CoordinateMapKey = None, + coordinate_field_map_key: CoordinateMapKey = None, coordinate_manager: CoordinateManager = None, quantization_mode: SparseTensorQuantizationMode = SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE, # optional manager related arguments @@ -80,10 +81,20 @@ def __init__( if inverse_mapping is not None: self.inverse_mapping = inverse_mapping + self.coordinate_field_map_key = coordinate_field_map_key + if coordinate_field_map_key is None: + assert coordinates is not None + self._CC = coordinates.float() + self.coordinate_field_map_key = self._manager.insert_field( + self._CC, *self.coordinate_map_key.get_key() + ) + def initialize_coordinates(self, coordinates, features, coordinate_map_key): - self._CC = coordinates - assert not isinstance(coordinates, (torch.IntTensor, torch.cuda.IntTensor)) - int_coordinates = torch.floor(coordinates).int() + + if not isinstance(coordinates, (torch.IntTensor, torch.cuda.IntTensor)): + int_coordinates = torch.floor(coordinates).int() + else: + int_coordinates = coordinates ( self.coordinate_map_key, @@ -137,11 +148,11 @@ def coordinates(self): different instances in a batch. """ if self._CC is None: - self._CC = self._get_continuous_coordinates() + self._CC = self._get_coordinate_field() return self._CC - def _get_continuous_coordinates(self): - return self._manager.get_continuous_coordinates(self.coordinate_map_key) + def _get_coordinate_field(self): + return self._manager.get_coordinate_field(self.coordinate_field_map_key) def sparse(self): r"""Converts the current sparse tensor field to a sparse tensor.""" @@ -172,6 +183,7 @@ def sparse(self): "_F", "_D", "coordinate_map_key", + "coordinate_field_map_key", "_manager", "unique_index", "inverse_mapping", diff --git a/pybind/extern.hpp b/pybind/extern.hpp index af7dfd8d..2e96f635 100644 --- a/pybind/extern.hpp +++ b/pybind/extern.hpp @@ -564,6 +564,7 @@ void instantiate_manager(py::module &m, const std::string &dtypestr) { py::overload_cast( &manager_type::to_string, py::const_)) .def("insert_and_map", &manager_type::insert_and_map) + .def("insert_field", &manager_type::insert_field) .def("stride", &manager_type::py_stride) .def("origin", &manager_type::py_origin) .def("get_coordinates", &manager_type::get_coordinates) diff --git a/src/coordinate_map_cpu.hpp b/src/coordinate_map_cpu.hpp index 98c083a7..3daa5cf3 100644 --- a/src/coordinate_map_cpu.hpp +++ b/src/coordinate_map_cpu.hpp @@ -34,23 +34,23 @@ namespace minkowski { template class TemplatedAllocator = std::allocator> -class CoordinatesCPU +class CoordinateFieldMapCPU : public CoordinateMap { // Coordinate wrapper public: using base_type = CoordinateMap; - using self_type = CoordinatesCPU; + using self_type = CoordinateFieldMapCPU; using size_type = typename base_type::size_type; using index_type = typename base_type::index_type; using stride_type = typename base_type::stride_type; using byte_allocator_type = TemplatedAllocator; public: - CoordinatesCPU() = delete; - CoordinatesCPU(size_type const number_of_coordinates, - size_type const coordinate_size, - stride_type const &stride = {1}, - byte_allocator_type alloc = byte_allocator_type()) + CoordinateFieldMapCPU() = delete; + CoordinateFieldMapCPU(size_type const number_of_coordinates, + size_type const coordinate_size, + stride_type const &stride = {1}, + byte_allocator_type alloc = byte_allocator_type()) : base_type(number_of_coordinates, coordinate_size, stride, alloc), m_size(number_of_coordinates) { base_type::reserve(number_of_coordinates); @@ -67,13 +67,13 @@ class CoordinatesCPU size_type N = (coordinate_end - coordinate_begin) / m_coordinate_size; base_type::allocate(N); // copy data directly to the ptr - std::copy_n(base_type::coordinate_data(), N * m_coordinate_size, - coordinate_begin); + std::copy_n(coordinate_begin, N * m_coordinate_size, + base_type::coordinate_data()); } void copy_coordinates(coordinate_type *dst_coordinate) const { - std::copy_n(dst_coordinate, size() * m_coordinate_size, - base_type::const_coordinate_data()); + std::copy_n(base_type::const_coordinate_data(), size() * m_coordinate_size, + dst_coordinate); } inline size_type size() const noexcept { return m_size; } diff --git a/src/coordinate_map_gpu.cu b/src/coordinate_map_gpu.cu index 452e33b4..63a5ef40 100644 --- a/src/coordinate_map_gpu.cu +++ b/src/coordinate_map_gpu.cu @@ -1364,14 +1364,14 @@ void CoordinateMapGPU::copy_coordinates( } // Template instantiation -template class CoordinatesGPU; -template class CoordinatesGPU; -template class CoordinatesGPU; -template class CoordinatesGPU; +template class CoordinateFieldMapGPU; +template class CoordinateFieldMapGPU; +template class CoordinateFieldMapGPU; +template class CoordinateFieldMapGPU; template class CoordinateMapGPU; diff --git a/src/coordinate_map_gpu.cuh b/src/coordinate_map_gpu.cuh index 84eac4ed..a819053d 100644 --- a/src/coordinate_map_gpu.cuh +++ b/src/coordinate_map_gpu.cuh @@ -40,23 +40,23 @@ namespace minkowski { template class TemplatedAllocator = detail::c10_allocator> -class CoordinatesGPU +class CoordinateFieldMapGPU : public CoordinateMap { // Coordinate wrapper public: using base_type = CoordinateMap; - using self_type = CoordinatesGPU; + using self_type = CoordinateFieldMapGPU; using size_type = typename base_type::size_type; using index_type = typename base_type::index_type; using stride_type = typename base_type::stride_type; using byte_allocator_type = TemplatedAllocator; public: - CoordinatesGPU() = delete; - CoordinatesGPU(size_type const number_of_coordinates, - size_type const coordinate_size, - stride_type const &stride = {1}, - byte_allocator_type alloc = byte_allocator_type()) + CoordinateFieldMapGPU() = delete; + CoordinateFieldMapGPU(size_type const number_of_coordinates, + size_type const coordinate_size, + stride_type const &stride = {1}, + byte_allocator_type alloc = byte_allocator_type()) : base_type(number_of_coordinates, coordinate_size, stride, alloc), m_size(number_of_coordinates) { base_type::reserve(number_of_coordinates); diff --git a/src/coordinate_map_manager.cpp b/src/coordinate_map_manager.cpp index b75c1c54..32789e5f 100644 --- a/src/coordinate_map_manager.cpp +++ b/src/coordinate_map_manager.cpp @@ -138,14 +138,14 @@ CoordsManager::getUnionMap(vector py_in_coords_keys, namespace detail { -template -struct insert_and_map_functor { - - std::pair operator()( - coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, - CoordinateMapManager - &manager) { +template +struct insert_and_map_functor { + + std::pair + operator()(coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, + CoordinateMapManager &manager) { LOG_DEBUG("initialize_and_map"); uint32_t const N = th_coordinate.size(0); uint32_t const coordinate_size = th_coordinate.size(1); @@ -185,6 +185,29 @@ struct insert_and_map_functor +struct insert_field_functor< + coordinate_type, coordinate_field_type, std::allocator, CoordinateMapCPU, + CoordinateFieldMapCPU> { + + void + operator()(coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, + CoordinateMapManager &manager) { + LOG_DEBUG("insert field"); + uint32_t const N = th_coordinate.size(0); + uint32_t const coordinate_size = th_coordinate.size(1); + coordinate_field_type *p_coordinate = + th_coordinate.data_ptr(); + auto map = CoordinateFieldMapCPU( + N, coordinate_size, map_key.first); + map.insert(p_coordinate, p_coordinate + N * coordinate_size); + + LOG_DEBUG("insert map with tensor_stride", map_key.first); + manager.insert_field_map(map_key, map); + } +}; + } // namespace detail /* @@ -197,12 +220,72 @@ struct insert_and_map_functor class TemplatedAllocator, + template class A> + class CoordinateMapType> +py::object CoordinateMapManager:: + insert_field(at::Tensor const &coordinates, + default_types::stride_type const tensor_stride, + std::string const string_id) { + + torch::TensorArg arg_coordinate(coordinates, "coordinates", 0); + torch::CheckedFrom c = "initialize"; + torch::checkContiguous(c, arg_coordinate); + + // must match coordinate_type + torch::checkScalarType(c, arg_coordinate, torch::kFloat); + torch::checkBackend(c, arg_coordinate.tensor, + detail::is_cpu_coordinate_map::value + ? torch::Backend::CPU + : torch::Backend::CUDA); + torch::checkDim(c, arg_coordinate, 2); + + auto const coordinate_size = (index_type)coordinates.size(1); + + // Basic assertions + ASSERT(coordinate_size - 1 == tensor_stride.size(), + "The coordinate dimension (coordinate_size - 1):", coordinate_size - 1, + " must match the size of tensor stride:", ArrToString(tensor_stride)); + + // generate the map_key + coordinate_map_key_type map_key = std::make_pair(tensor_stride, string_id); + if (m_field_coordinates.find(map_key) != m_field_coordinates.end()) { + WARNING(true, "CoordinateMapKey collision detected:", map_key, + "generating new string id."); + map_key = get_random_string_id(tensor_stride, string_id); + } + + LOG_DEBUG("initializing a map with tensor stride:", map_key.first, + "string id:", map_key.second); + // Create the concurrent coords map + detail::insert_field_functor()(map_key, coordinates, *this); + + py::object py_key = py::cast(new CoordinateMapKey(coordinate_size, map_key)); + + return py_key; +} + +/* + * coords: coordinates in IntTensor + * mapping: output mapping in IntTensor + * tensor_strides: current tensor strides this coords will be initializeds + * force_creation: even when there's a duplicate coords with the same tensor + * strides. + * force_remap: if there's duplicate coords, remap + * allow_duplicate_coords: create map when there are duplicates in the + * coordinates + */ +template class TemplatedAllocator, template class A> class CoordinateMapType> std::pair> -CoordinateMapManager:: +CoordinateMapManager:: insert_and_map(at::Tensor const &coordinate, default_types::stride_type const tensor_stride, std::string const string_id) { @@ -237,9 +320,9 @@ CoordinateMapManager:: "string id:", map_key.second); // Create the concurrent coords map auto const map_inverse_map = - detail::insert_and_map_functor()(map_key, coordinate, - *this); + detail::insert_and_map_functor()( + map_key, coordinate, *this); py::object py_key = py::cast(new CoordinateMapKey(coordinate_size, map_key)); @@ -247,14 +330,14 @@ CoordinateMapManager:: } // stride -template class TemplatedAllocator, template class A> class CoordinateMapType> -std::pair -CoordinateMapManager:: - stride(coordinate_map_key_type const &in_map_key, - stride_type const &kernel_stride) { +std::pair CoordinateMapManager< + coordinate_type, coordinate_field_type, TemplatedAllocator, + CoordinateMapType>::stride(coordinate_map_key_type const &in_map_key, + stride_type const &kernel_stride) { ASSERT(exists(in_map_key), ERROR_MAP_NOT_FOUND); // check if the key exists. LOG_DEBUG("In tensor stride:", in_map_key.first, @@ -274,12 +357,13 @@ CoordinateMapManager:: return std::make_pair(out_map_key, !exists_out_map); } -template class TemplatedAllocator, template class A> class CoordinateMapType> std::pair -CoordinateMapManager:: +CoordinateMapManager:: stride_region(coordinate_map_key_type const &in_map_key, cpu_kernel_region &kernel, bool generate_new_map) { @@ -308,12 +392,12 @@ CoordinateMapManager:: return std::make_pair(out_map_key, !exists_out_map || generate_new_map); } -template class TemplatedAllocator, template class A> class CoordinateMapType> std::pair -CoordinateMapManager::origin() { ASSERT(m_coordinate_maps.size() > 0, "No coordinate map found"); // check if the key exists. @@ -349,14 +433,16 @@ CoordinateMapManager class TemplatedAllocator, template class A> class CoordinateMapType> coordinate_map_key_type -CoordinateMapManager:: - prune(coordinate_map_key_type const &in_key, bool const *keep_begin, - bool const *keep_end) { +CoordinateMapManager::prune(coordinate_map_key_type const + &in_key, + bool const *keep_begin, + bool const *keep_end) { ASSERT(exists(in_key), "In map doesn't exist"); @@ -419,15 +505,17 @@ template <> struct swap_in_out_map_functor { * Given tensor_stride_src and tensor_stride_dst, find the respective coord_maps * and return the indices of the coord_map_ind in coord_map_dst */ -template class TemplatedAllocator, template class A> class CoordinateMapType> -typename CoordinateMapManager::kernel_map_type const & -CoordinateMapManager:: - kernel_map(CoordinateMapKey const *p_in_map_key, - CoordinateMapKey const *p_out_map_key) { +CoordinateMapManager< + coordinate_type, coordinate_field_type, TemplatedAllocator, + CoordinateMapType>::kernel_map(CoordinateMapKey const *p_in_map_key, + CoordinateMapKey const *p_out_map_key) { // when kernel has volume 1 auto const &map_it = m_coordinate_maps.find(p_in_map_key->get_key()); ASSERT(map_it != m_coordinate_maps.end(), ERROR_MAP_NOT_FOUND); @@ -444,20 +532,23 @@ CoordinateMapManager:: * Given tensor_stride_src and tensor_stride_dst, find the respective coord_maps * and return the indices of the coord_map_ind in coord_map_dst */ -template class TemplatedAllocator, template class A> class CoordinateMapType> -typename CoordinateMapManager::kernel_map_type const & -CoordinateMapManager:: - kernel_map(CoordinateMapKey const *p_in_map_key, - CoordinateMapKey const *p_out_map_key, - stride_type const &kernel_size, // - stride_type const &kernel_stride, - stride_type const &kernel_dilation, - RegionType::Type const region_type, at::Tensor const &offset, - bool is_transpose, bool is_pool) { +CoordinateMapManager< + coordinate_type, coordinate_field_type, TemplatedAllocator, + CoordinateMapType>::kernel_map(CoordinateMapKey const *p_in_map_key, + CoordinateMapKey const *p_out_map_key, + stride_type const &kernel_size, // + stride_type const &kernel_stride, + stride_type const &kernel_dilation, + RegionType::Type const region_type, + at::Tensor const &offset, bool is_transpose, + bool is_pool) { ASSERT(region_type != RegionType::CUSTOM, "Not implemented yet."); if (region_type == RegionType::CUSTOM) ASSERT(offset.is_cuda() == @@ -663,14 +754,16 @@ struct origin_map_functor class TemplatedAllocator, template class A> class CoordinateMapType> -typename CoordinateMapManager::kernel_map_type const & -CoordinateMapManager:: - origin_map(CoordinateMapKey const *p_in_map_key) { +CoordinateMapManager::origin_map(CoordinateMapKey const + *p_in_map_key) { ASSERT(exists(p_in_map_key), ERROR_MAP_NOT_FOUND); kernel_map_key_type const kernel_map_key = origin_map_key(p_in_map_key->get_key()); @@ -687,13 +780,14 @@ CoordinateMapManager:: return m_kernel_maps[kernel_map_key]; } -template class TemplatedAllocator, template class A> class CoordinateMapType> std::pair> -CoordinateMapManager:: - origin_map_th(CoordinateMapKey const *p_in_map_key) { +CoordinateMapManager::origin_map_th(CoordinateMapKey const + *p_in_map_key) { kernel_map_type const &kernel_map = origin_map(p_in_map_key); coordinate_map_key_type const origin_key = origin().first; @@ -862,13 +956,14 @@ CoordsManager::getUnionInOutMaps(vector py_in_coords_keys, */ /* Helper functions */ -template class TemplatedAllocator, template class A> class CoordinateMapType> at::Tensor -CoordinateMapManager:: - get_coordinates(CoordinateMapKey const *p_key) const { +CoordinateMapManager::get_coordinates(CoordinateMapKey const + *p_key) const { ASSERT(exists(p_key), ERROR_MAP_NOT_FOUND); auto const it = m_coordinate_maps.find(p_key->get_key()); ASSERT(it != m_coordinate_maps.end(), ERROR_MAP_NOT_FOUND); @@ -893,7 +988,43 @@ CoordinateMapManager:: return coordinates; } +template class TemplatedAllocator, + template class A> + class CoordinateMapType> +at::Tensor CoordinateMapManager:: + get_coordinate_field(CoordinateMapKey const *p_key) const { + ASSERT(exists(p_key), ERROR_MAP_NOT_FOUND); + auto const it = m_field_coordinates.find(p_key->get_key()); + ASSERT(it != m_field_coordinates.end(), ERROR_MAP_NOT_FOUND); + auto const &map = it->second; + auto const nrows = map.size(); + auto const ncols = map.coordinate_size(); + + auto options = torch::TensorOptions() + .dtype(std::is_same::value + ? torch::kFloat + : torch::kDouble) + .requires_grad(false); + + if (!detail::is_cpu_coordinate_map::value) { +#ifndef CPU_ONLY + auto device_id = at::cuda::current_device(); + options = options.device(torch::kCUDA, device_id); +#else + ASSERT(false, ERROR_CPU_ONLY); +#endif + } + at::Tensor coordinates = torch::empty({(long)nrows, (long)ncols}, options); + + // copy to the out coords + map.copy_coordinates(coordinates.template data_ptr()); + return coordinates; +} + template class CoordinateMapManager; } // end namespace minkowski diff --git a/src/coordinate_map_manager.cu b/src/coordinate_map_manager.cu index 333a2071..52d89015 100644 --- a/src/coordinate_map_manager.cu +++ b/src/coordinate_map_manager.cu @@ -42,15 +42,15 @@ __global__ void cuda_copy_n(src_type const *src, uint32_t N, dst_type *dst) { CUDA_KERNEL_LOOP(index, N) { dst[index] = src[index]; } } -template class TemplatedAllocator> -struct insert_and_map_functor { +struct insert_and_map_functor { - std::pair - operator()(coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, - CoordinateMapManager &manager) { + std::pair operator()( + coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, + CoordinateMapManager &manager) { uint32_t const N = th_coordinate.size(0); uint32_t const coordinate_size = th_coordinate.size(1); coordinate_type *p_coordinate = th_coordinate.data_ptr(); @@ -109,6 +109,31 @@ struct insert_and_map_functor class TemplatedAllocator> +struct insert_field_functor< + coordinate_type, coordinate_field_type, TemplatedAllocator, + CoordinateMapGPU, + CoordinateFieldMapGPU> { + + void operator()( + coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, + CoordinateMapManager &manager) { + LOG_DEBUG("insert field"); + uint32_t const N = th_coordinate.size(0); + uint32_t const coordinate_size = th_coordinate.size(1); + coordinate_field_type *p_coordinate = + th_coordinate.data_ptr(); + auto map = CoordinateFieldMapGPU( + N, coordinate_size, map_key.first); + map.insert(p_coordinate, p_coordinate + N * coordinate_size); + + LOG_DEBUG("insert map with tensor_stride", map_key.first); + manager.insert_field_map(map_key, map); + } +}; + template class TemplatedAllocator> struct kernel_map_functor< @@ -270,9 +295,11 @@ struct origin_map_functor< } // namespace detail -template class CoordinateMapManager; -template class CoordinateMapManager; +template class CoordinateMapManager< + default_types::dcoordinate_type, default_types::ccoordinate_type, + detail::default_allocator, CoordinateMapGPU>; +template class CoordinateMapManager; } // end namespace minkowski diff --git a/src/coordinate_map_manager.hpp b/src/coordinate_map_manager.hpp index 85b75edd..74efb74c 100644 --- a/src/coordinate_map_manager.hpp +++ b/src/coordinate_map_manager.hpp @@ -80,7 +80,7 @@ default_types::stride_type _fill_vec(size_t const len) { } // namespace detail -template class TemplatedAllocator, template class A> class CoordinateMapType> @@ -91,15 +91,16 @@ class CoordinateMapManager { using stride_type = default_types::stride_type; using map_type = CoordinateMapType; #ifndef CPU_ONLY - using coordinates_type = typename std::conditional< + using field_map_type = typename std::conditional< detail::is_cpu_coordinate_map::value, - CoordinatesCPU, - CoordinatesGPU>::type; + CoordinateFieldMapCPU, + CoordinateFieldMapGPU>::type; #else - using coordinates_type = CoordinatesCPU; + using field_map_type = + CoordinateFieldMapCPU; #endif - using self_type = CoordinateMapManager; + using self_type = CoordinateMapManager; using map_collection_type = std::map; using kernel_map_type = @@ -155,10 +156,9 @@ class CoordinateMapManager { /**************************************************************************** * Coordinate generation, modification, and initialization entry functions ****************************************************************************/ - // TODO - // py::object insert(at::Tensor const &th_coordinate, - // stride_type const tensor_stride, - // std::string const string_id = ""); + py::object insert_field(at::Tensor const &th_coordinate, + stride_type const tensor_stride, + std::string const string_id = ""); /* * New coordinate map initialzation function. @@ -223,6 +223,15 @@ class CoordinateMapManager { return result.second; } + bool insert_field_map(coordinate_map_key_type map_key, field_map_type &map) { + LOG_DEBUG("insert map with tensor_stride", map_key.first); + auto result = m_field_coordinates.insert( + std::make_pair( + std::move(map_key), std::move(map))); + LOG_DEBUG("map insertion", result.second); + return result.second; + } + typename map_collection_type::iterator find(coordinate_map_key_type const &map_key) { return m_coordinate_maps.find(map_key); @@ -260,6 +269,8 @@ class CoordinateMapManager { at::Tensor get_coordinates(CoordinateMapKey const *p_key) const; + at::Tensor get_coordinate_field(CoordinateMapKey const *p_key) const; + std::vector get_coordinate_map_keys(stride_type const tensor_stride) const { std::vector keys; @@ -395,9 +406,9 @@ class CoordinateMapManager { m_coordinate_maps; // CoordinateMapManager managed coordinates - std::map - m_coordinates; + m_field_coordinates; // CoordinateMapManager owns the kernel maps std::unordered_map class TemplatedAllocator, + template class A> + class CoordinateMapType, + typename field_map_type> +struct insert_field_functor { + + void operator()( + coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, + CoordinateMapManager &manager); +}; + // a partial specialization functor for insertion -template class TemplatedAllocator, template class A> class CoordinateMapType> struct insert_and_map_functor { - std::pair - operator()(coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, - CoordinateMapManager &manager); + std::pair operator()( + coordinate_map_key_type &map_key, at::Tensor const &th_coordinate, + CoordinateMapManager &manager); }; // a partial specialization functor for kernel map generation @@ -482,13 +506,15 @@ struct origin_map_functor { // type defs template using cpu_manager_type = - CoordinateMapManager; + CoordinateMapManager; #ifndef CPU_ONLY template class TemplatedAllocator> using gpu_manager_type = - CoordinateMapManager; + CoordinateMapManager; template using gpu_default_manager_type = diff --git a/tests/python/tensor_field.py b/tests/python/tensor_field.py index f3a45aa9..5fa34a98 100644 --- a/tests/python/tensor_field.py +++ b/tests/python/tensor_field.py @@ -101,8 +101,8 @@ def test_network_device(self): def slice(self): coords, colors, pcd = load_file("1.ply") voxel_size = 0.02 - colors = torch.from_numpy(colors) - bcoords = batched_coordinates([coords / voxel_size], return_int=False) + colors = torch.from_numpy(colors).float() + bcoords = batched_coordinates([coords / voxel_size], return_int=False).float() tfield = TensorField(colors, bcoords) network = nn.Sequential(