Skip to content

Commit

Permalink
field map insertion
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent 75e9e37 commit d30e6c7
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 141 deletions.
32 changes: 28 additions & 4 deletions MinkowskiEngine/MinkowskiCoordinateManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,15 @@ def __init__(

def insert_and_map(
self,
coordinates: torch.IntTensor,
tensor_stride: Union[int, Sequence, np.ndarray, torch.Tensor],
coordinates: torch.Tensor,
tensor_stride: Union[int, Sequence, np.ndarray],
string_id: str = "",
) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]:
r"""create a new coordinate map and returns
:attr:`coordinates`: `torch.IntTensor` (`CUDA` if coordinate_map_type
== `CoordinateMapType.GPU`) that defines the coordinates.
:attr:`coordinates`: `torch.Tensor` (Int tensor. `CUDA` if
coordinate_map_type == `CoordinateMapType.GPU`) that defines the
coordinates.
Example::
Expand All @@ -168,6 +169,29 @@ def insert_and_map(
"""
return self._manager.insert_and_map(coordinates, tensor_stride, string_id)

def insert_field(
self,
coordinates: torch.Tensor,
tensor_stride: Union[int, Sequence, np.ndarray],
string_id: str = "",
) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]:
r"""create a new coordinate map and returns
:attr:`coordinates`: `torch.FloatTensor` (`CUDA` if coordinate_map_type
== `CoordinateMapType.GPU`) that defines the coordinates.
Example::
>>> manager = CoordinateManager(D=1)
>>> coordinates = torch.FloatTensor([[0, 0.1], [0, 2.3], [0, 1.2], [0, 2.4]])
>>> key, (unique_map, inverse_map) = manager.insert(coordinates, [1])
>>> print(key) # key is tensor_stride, string_id [1]:""
>>> torch.all(coordinates[unique_map] == manager.get_coordinates(key)) # True
>>> torch.all(coordinates == coordinates[unique_map][inverse_map]) # True
"""
return self._manager.insert_field(coordinates, tensor_stride, string_id)

def stride(
self,
coordinate_map_key: CoordinateMapKey,
Expand Down
5 changes: 3 additions & 2 deletions MinkowskiEngine/MinkowskiNonlinearity.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ def __init__(self, *args, **kwargs):
def forward(self, input):
output = self.module(input.F)
if isinstance(input, TensorField):
return input.__class__(
return TensorField(
output,
coordinate_map_key=input.coordinate_map_key,
coordinate_field_map_key=input.coordinate_field_map_key,
coordinate_manager=input.coordinate_manager,
inverse_mapping=input.inverse_mapping,
quantization_mode=input.quantization_mode,
)
else:
return input.__class__(
return SparseTensor(
output,
coordinate_map_key=input.coordinate_map_key,
coordinate_manager=input.coordinate_manager,
Expand Down
5 changes: 3 additions & 2 deletions MinkowskiEngine/MinkowskiNormalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ def __init__(
def forward(self, input):
output = self.bn(input.F)
if isinstance(input, TensorField):
return input.__class__(
return TensorField(
output,
coordinate_map_key=input.coordinate_map_key,
coordinate_field_map_key=input.coordinate_field_map_key,
coordinate_manager=input.coordinate_manager,
inverse_mapping=input.inverse_mapping,
quantization_mode=input.quantization_mode,
)
else:
return input.__class__(
return SparseTensor(
output,
coordinate_map_key=input.coordinate_map_key,
coordinate_manager=input.coordinate_manager,
Expand Down
5 changes: 3 additions & 2 deletions MinkowskiEngine/MinkowskiOps.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,16 @@ def __init__(self, in_features, out_features, bias=True):
def forward(self, input: Union[SparseTensor, TensorField]):
output = self.linear(input.F)
if isinstance(input, TensorField):
return input.__class__(
return TensorField(
output,
coordinate_map_key=input.coordinate_map_key,
coordinate_field_map_key=input.coordinate_field_map_key,
coordinate_manager=input.coordinate_manager,
inverse_mapping=input.inverse_mapping,
quantization_mode=input.quantization_mode,
)
else:
return input.__class__(
return SparseTensor(
output,
coordinate_map_key=input.coordinate_map_key,
coordinate_manager=input.coordinate_manager,
Expand Down
55 changes: 40 additions & 15 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,24 @@ def slice(self, X, slicing_mode=0):
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
from MinkowskiTensorField import TensorField

return TensorField(
self.F[X.inverse_mapping],
coordinate_map_key=self.coordinate_map_key,
coordinate_manager=self.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)
if isinstance(X, TensorField):
return TensorField(
self.F[X.inverse_mapping],
coordinate_map_key=X.coordinate_map_key,
coordinate_field_map_key=X.coordinate_field_map_key,
coordinate_manager=X.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)
else:
return TensorField(
self.F[X.inverse_mapping],
coordinates=self.C[X.inverse_mapping],
coordinate_map_key=X.coordinate_map_key,
coordinate_manager=X.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)

def cat_slice(self, X, slicing_mode=0):
r"""
Expand Down Expand Up @@ -428,13 +439,25 @@ def cat_slice(self, X, slicing_mode=0):
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
from MinkowskiTensorField import TensorField

return TensorField(
torch.cat((self.F[X.inverse_mapping], X.F), dim=1),
coordinate_map_key=self.coordinate_map_key,
coordinate_manager=self.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)
features = torch.cat((self.F[X.inverse_mapping], X.F), dim=1)
if isinstance(X, TensorField):
return TensorField(
features,
coordinate_map_key=X.coordinate_map_key,
coordinate_field_map_key=X.coordinate_field_map_key,
coordinate_manager=X.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)
else:
return TensorField(
features,
coordinates=self.C[X.inverse_mapping],
coordinate_map_key=X.coordinate_map_key,
coordinate_manager=X.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)

def features_at_coords(self, query_coords: torch.Tensor):
r"""Extract features at the specified coordinate matrix.
Expand Down Expand Up @@ -501,7 +524,9 @@ def _get_coordinate_map_key(
(
coordinate_map_key,
(unique_index, inverse_mapping),
) = input._manager.insert_and_map(coordinates, *coordinate_map_key.get_key())
) = input._manager.insert_and_map(
coordinates, *coordinate_map_key.get_key()
)
elif isinstance(coordinates, SparseTensor):
coordinate_map_key = coordinates.coordinate_map_key
else: # CoordinateMapKey type due to the previous assertion
Expand Down
24 changes: 18 additions & 6 deletions MinkowskiEngine/MinkowskiTensorField.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
# optional coordinate related arguments
tensor_stride: StrideType = 1,
coordinate_map_key: CoordinateMapKey = None,
coordinate_field_map_key: CoordinateMapKey = None,
coordinate_manager: CoordinateManager = None,
quantization_mode: SparseTensorQuantizationMode = SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
# optional manager related arguments
Expand Down Expand Up @@ -80,10 +81,20 @@ def __init__(
if inverse_mapping is not None:
self.inverse_mapping = inverse_mapping

self.coordinate_field_map_key = coordinate_field_map_key
if coordinate_field_map_key is None:
assert coordinates is not None
self._CC = coordinates.float()
self.coordinate_field_map_key = self._manager.insert_field(
self._CC, *self.coordinate_map_key.get_key()
)

def initialize_coordinates(self, coordinates, features, coordinate_map_key):
self._CC = coordinates
assert not isinstance(coordinates, (torch.IntTensor, torch.cuda.IntTensor))
int_coordinates = torch.floor(coordinates).int()

if not isinstance(coordinates, (torch.IntTensor, torch.cuda.IntTensor)):
int_coordinates = torch.floor(coordinates).int()
else:
int_coordinates = coordinates

(
self.coordinate_map_key,
Expand Down Expand Up @@ -137,11 +148,11 @@ def coordinates(self):
different instances in a batch.
"""
if self._CC is None:
self._CC = self._get_continuous_coordinates()
self._CC = self._get_coordinate_field()
return self._CC

def _get_continuous_coordinates(self):
return self._manager.get_continuous_coordinates(self.coordinate_map_key)
def _get_coordinate_field(self):
return self._manager.get_coordinate_field(self.coordinate_field_map_key)

def sparse(self):
r"""Converts the current sparse tensor field to a sparse tensor."""
Expand Down Expand Up @@ -172,6 +183,7 @@ def sparse(self):
"_F",
"_D",
"coordinate_map_key",
"coordinate_field_map_key",
"_manager",
"unique_index",
"inverse_mapping",
Expand Down
1 change: 1 addition & 0 deletions pybind/extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ void instantiate_manager(py::module &m, const std::string &dtypestr) {
py::overload_cast<minkowski::CoordinateMapKey const *>(
&manager_type::to_string, py::const_))
.def("insert_and_map", &manager_type::insert_and_map)
.def("insert_field", &manager_type::insert_field)
.def("stride", &manager_type::py_stride)
.def("origin", &manager_type::py_origin)
.def("get_coordinates", &manager_type::get_coordinates)
Expand Down
22 changes: 11 additions & 11 deletions src/coordinate_map_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ namespace minkowski {

template <typename coordinate_type,
template <typename T> class TemplatedAllocator = std::allocator>
class CoordinatesCPU
class CoordinateFieldMapCPU
: public CoordinateMap<coordinate_type, TemplatedAllocator> {
// Coordinate wrapper
public:
using base_type = CoordinateMap<coordinate_type, TemplatedAllocator>;
using self_type = CoordinatesCPU<coordinate_type, TemplatedAllocator>;
using self_type = CoordinateFieldMapCPU<coordinate_type, TemplatedAllocator>;
using size_type = typename base_type::size_type;
using index_type = typename base_type::index_type;
using stride_type = typename base_type::stride_type;
using byte_allocator_type = TemplatedAllocator<char>;

public:
CoordinatesCPU() = delete;
CoordinatesCPU(size_type const number_of_coordinates,
size_type const coordinate_size,
stride_type const &stride = {1},
byte_allocator_type alloc = byte_allocator_type())
CoordinateFieldMapCPU() = delete;
CoordinateFieldMapCPU(size_type const number_of_coordinates,
size_type const coordinate_size,
stride_type const &stride = {1},
byte_allocator_type alloc = byte_allocator_type())
: base_type(number_of_coordinates, coordinate_size, stride, alloc),
m_size(number_of_coordinates) {
base_type::reserve(number_of_coordinates);
Expand All @@ -67,13 +67,13 @@ class CoordinatesCPU
size_type N = (coordinate_end - coordinate_begin) / m_coordinate_size;
base_type::allocate(N);
// copy data directly to the ptr
std::copy_n(base_type::coordinate_data(), N * m_coordinate_size,
coordinate_begin);
std::copy_n(coordinate_begin, N * m_coordinate_size,
base_type::coordinate_data());
}

void copy_coordinates(coordinate_type *dst_coordinate) const {
std::copy_n(dst_coordinate, size() * m_coordinate_size,
base_type::const_coordinate_data());
std::copy_n(base_type::const_coordinate_data(), size() * m_coordinate_size,
dst_coordinate);
}

inline size_type size() const noexcept { return m_size; }
Expand Down
16 changes: 8 additions & 8 deletions src/coordinate_map_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1364,14 +1364,14 @@ void CoordinateMapGPU<coordinate_type, TemplatedAllocator>::copy_coordinates(
}

// Template instantiation
template class CoordinatesGPU<default_types::dcoordinate_type,
detail::default_allocator>;
template class CoordinatesGPU<default_types::dcoordinate_type,
detail::c10_allocator>;
template class CoordinatesGPU<default_types::ccoordinate_type,
detail::default_allocator>;
template class CoordinatesGPU<default_types::ccoordinate_type,
detail::c10_allocator>;
template class CoordinateFieldMapGPU<default_types::dcoordinate_type,
detail::default_allocator>;
template class CoordinateFieldMapGPU<default_types::dcoordinate_type,
detail::c10_allocator>;
template class CoordinateFieldMapGPU<default_types::ccoordinate_type,
detail::default_allocator>;
template class CoordinateFieldMapGPU<default_types::ccoordinate_type,
detail::c10_allocator>;

template class CoordinateMapGPU<default_types::dcoordinate_type,
detail::default_allocator>;
Expand Down
14 changes: 7 additions & 7 deletions src/coordinate_map_gpu.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,23 @@ namespace minkowski {
template <typename coordinate_type, template <typename T>
class TemplatedAllocator =
detail::c10_allocator>
class CoordinatesGPU
class CoordinateFieldMapGPU
: public CoordinateMap<coordinate_type, TemplatedAllocator> {
// Coordinate wrapper
public:
using base_type = CoordinateMap<coordinate_type, TemplatedAllocator>;
using self_type = CoordinatesGPU<coordinate_type, TemplatedAllocator>;
using self_type = CoordinateFieldMapGPU<coordinate_type, TemplatedAllocator>;
using size_type = typename base_type::size_type;
using index_type = typename base_type::index_type;
using stride_type = typename base_type::stride_type;
using byte_allocator_type = TemplatedAllocator<char>;

public:
CoordinatesGPU() = delete;
CoordinatesGPU(size_type const number_of_coordinates,
size_type const coordinate_size,
stride_type const &stride = {1},
byte_allocator_type alloc = byte_allocator_type())
CoordinateFieldMapGPU() = delete;
CoordinateFieldMapGPU(size_type const number_of_coordinates,
size_type const coordinate_size,
stride_type const &stride = {1},
byte_allocator_type alloc = byte_allocator_type())
: base_type(number_of_coordinates, coordinate_size, stride, alloc),
m_size(number_of_coordinates) {
base_type::reserve(number_of_coordinates);
Expand Down
Loading

0 comments on commit d30e6c7

Please sign in to comment.