Skip to content

Commit

Permalink
update slice and tfield
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Jan 3, 2021
1 parent 9e1322f commit f157ef8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
50 changes: 26 additions & 24 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,15 +512,13 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True):
tensor_stride = torch.IntTensor(self.tensor_stride)
return dense_F, min_coordinate, tensor_stride

def slice(self, X, slicing_mode=0):
def slice(self, X):
r"""
Args:
:attr:`X` (:attr:`MinkowskiEngine.SparseTensor`): a sparse tensor
that discretized the original input.
:attr:`slicing_mode`: For future updates.
Returns:
:attr:`tensor_field` (:attr:`MinkowskiEngine.TensorField`): the
resulting tensor field contains features on the continuous
Expand All @@ -530,7 +528,7 @@ def slice(self, X, slicing_mode=0):
>>> # coords, feats from a data loader
>>> print(len(coords)) # 227742
>>> tfield = ME.TensorField(coords=coords, feats=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> tfield = ME.TensorField(coordinates=coords, features=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> print(len(tfield)) # 227742
>>> sinput = tfield.sparse() # 161890 quantization results in fewer voxels
>>> soutput = MinkUNet(sinput)
Expand All @@ -545,9 +543,7 @@ def slice(self, X, slicing_mode=0):
SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
], "slice only available for sparse tensors with quantization RANDOM_SUBSAMPLE or UNWEIGHTED_AVERAGE"
assert (
X.coordinate_map_key == self.coordinate_map_key
), "Slice can only be applied on the same coordinates (coordinate_map_key)"

from MinkowskiTensorField import TensorField

if isinstance(X, TensorField):
Expand All @@ -557,23 +553,28 @@ def slice(self, X, slicing_mode=0):
coordinate_manager=X.coordinate_manager,
quantization_mode=X.quantization_mode,
)
else:
elif isinstance(X, SparseTensor):
assert (
X.coordinate_map_key == self.coordinate_map_key
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
return TensorField(
self.F[X.inverse_mapping],
coordinates=self.C[X.inverse_mapping],
coordinate_manager=X.coordinate_manager,
quantization_mode=X.quantization_mode,
coordinate_manager=self.coordinate_manager,
quantization_mode=self.quantization_mode,
)
else:
raise ValueError(
"Invalid input. The input must be an instance of TensorField or SparseTensor."
)

def cat_slice(self, X, slicing_mode=0):
def cat_slice(self, X):
r"""
Args:
:attr:`X` (:attr:`MinkowskiEngine.SparseTensor`): a sparse tensor
that discretized the original input.
:attr:`slicing_mode`: For future updates.
Returns:
:attr:`tensor_field` (:attr:`MinkowskiEngine.TensorField`): the
resulting tensor field contains the concatenation of features on the
Expand All @@ -584,7 +585,7 @@ def cat_slice(self, X, slicing_mode=0):
>>> # coords, feats from a data loader
>>> print(len(coords)) # 227742
>>> sinput = ME.SparseTensor(coords=coords, feats=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> sinput = ME.SparseTensor(coordinates=coords, features=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> print(len(sinput)) # 161890 quantization results in fewer voxels
>>> soutput = network(sinput)
>>> print(len(soutput)) # 161890 Output with the same resolution
Expand All @@ -596,29 +597,30 @@ def cat_slice(self, X, slicing_mode=0):
SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
], "slice only available for sparse tensors with quantization RANDOM_SUBSAMPLE or UNWEIGHTED_AVERAGE"
assert (
X.coordinate_map_key == self.coordinate_map_key
), "Slice can only be applied on the same coordinates (coordinate_map_key)"

from MinkowskiTensorField import TensorField

features = torch.cat((self.F[X.inverse_mapping], X.F), dim=1)
if isinstance(X, TensorField):
return TensorField(
features,
coordinate_map_key=X.coordinate_map_key,
coordinate_field_map_key=X.coordinate_field_map_key,
coordinate_manager=X.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
)
else:
elif isinstance(X, SparseTensor):
assert (
X.coordinate_map_key == self.coordinate_map_key
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
return TensorField(
features,
coordinates=self.C[X.inverse_mapping],
coordinate_map_key=X.coordinate_map_key,
coordinate_manager=X.coordinate_manager,
inverse_mapping=X.inverse_mapping,
quantization_mode=X.quantization_mode,
coordinate_manager=self.coordinate_manager,
quantization_mode=self.quantization_mode,
)
else:
raise ValueError(
"Invalid input. The input must be an instance of TensorField or SparseTensor."
)

def features_at_coordinates(self, query_coordinates: torch.Tensor):
Expand Down
15 changes: 14 additions & 1 deletion MinkowskiEngine/MinkowskiTensorField.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,25 @@ def sparse(self, quantization_mode=None):
if quantization_mode is None:
quantization_mode = self.quantization_mode

return SparseTensor(
sparse_tensor = SparseTensor(
self._F,
coordinates=self.coordinates,
quantization_mode=quantization_mode,
coordinate_manager=self.coordinate_manager,
)

# Save the inverse mapping
self._inverse_mapping = sparse_tensor.inverse_mapping
return sparse_tensor

@property
def inverse_mapping(self):
if not hasattr(self, "_inverse_mapping"):
raise ValueError(
"Did you run SparseTensor.slice? The slice must take a tensor field that returned TensorField.space."
)
return self._inverse_mapping

def __repr__(self):
return (
self.__class__.__name__
Expand Down Expand Up @@ -269,5 +281,6 @@ def __repr__(self):
"coordinate_field_map_key",
"_manager",
"quantization_mode",
"_inverse_mapping",
"_batch_rows",
)

0 comments on commit f157ef8

Please sign in to comment.