Skip to content

Commit

Permalink
decomposition permutation
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent 2f4ce0c commit 1cdebb9
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
15 changes: 8 additions & 7 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,18 +461,19 @@ def cat_slice(self, X, slicing_mode=0):
quantization_mode=X.quantization_mode,
)

def features_at_coords(self, query_coords: torch.Tensor):
def features_at_coordinates(self, query_coordinates: torch.Tensor):
r"""Extract features at the specified coordinate matrix.
Args:
:attr:`query_coords` (:attr:`torch.IntTensor`): a coordinate matrix
of size :math:`N \times (D + 1)` where :math:`D` is the size of the
spatial dimension.
:attr:`query_coordinates` (:attr:`torch.IntTensor`): a coordinate
matrix of size :math:`N \times (D + 1)` where :math:`D` is the size
of the spatial dimension.
Returns:
:attr:`query_feats` (:attr:`torch.Tensor`): a feature matrix of size
:math:`N \times D_F` where :math:`D_F` is the number of channels in
the feature. Features for the coordinates that are not found, it will be zero.
:attr:`queried_features` (:attr:`torch.Tensor`): a feature matrix of
size :math:`N \times D_F` where :math:`D_F` is the number of
channels in the feature. For coordinates not present in the current
sparse tensor, corresponding feature rows will be zeros.
:attr:`valid_rows` (:attr:`list`): a list of row indices that
contain valid values. The rest of the rows that are not found in the
Expand Down
76 changes: 76 additions & 0 deletions MinkowskiEngine/MinkowskiTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,32 @@ def _batchwise_row_indices(self):
_, self._batch_rows = self._manager.origin_map(self.coordinate_map_key)
return self._batch_rows

@property
def _sorted_batchwise_row_indices(self):
if self._sorted_batch_rows is None:
batch_rows = self._batchwise_row_indices
with torch.no_grad():
self._sorted_batch_rows = [t.sort()[0] for t in batch_rows]
return self._sorted_batch_rows

@property
def decomposition_permutations(self):
r"""Returns a list of indices per batch that where indices defines the permutation of the batch-wise decomposition.
Example::
>>> # coords, feats, labels are given. All follow the same order
>>> stensor = ME.SparseTensor(feats, coords)
>>> conv = ME.MinkowskiConvolution(in_channels=3, out_nchannel=3, kernel_size=3, dimension=3)
>>> list_of_featurs = stensor.decomposed_features
>>> list_of_permutations = stensor.decomposition_permutations
>>> # list_of_features == [feats[inds] for inds in list_of_permutations]
>>> list_of_decomposed_labels = [labels[inds] for inds in list_of_permutations]
>>> for curr_feats, curr_labels in zip(list_of_features, list_of_decomposed_labels):
>>> loss += torch.functional.mse_loss(curr_feats, curr_labels)
"""
return self._batchwise_row_indices

@property
def decomposed_coordinates(self):
r"""Returns a list of coordinates per batch.
Expand All @@ -419,6 +445,15 @@ def decomposed_coordinates(self):
\times D}` coordinates per batch where :math:`N_i` is the number of non
zero elements in the :math:`i`th batch index in :math:`D` dimensional
space.
.. note::
The order of coordinates is non-deterministic within each batch. Use
:attr:`decomposed_coordinates_and_features` to retrieve both
coordinates features with the same order. To retrieve the order the
decomposed coordinates is generated, use
:attr:`decomposition_permutations`.
"""
return [self.C[row_inds, 1:] for row_inds in self._batchwise_row_indices]

Expand All @@ -429,6 +464,15 @@ def coordinates_at(self, batch_index):
\times D}` coordinates at the specified batch index where :math:`N_i`
is the number of non zero elements in the :math:`i`th batch index in
:math:`D` dimensional space.
.. note::
The order of coordinates is non-deterministic within each batch. Use
:attr:`decomposed_coordinates_and_features` to retrieve both
coordinates features with the same order. To retrieve the order the
decomposed coordinates is generated, use
:attr:`decomposition_permutations`.
"""
return self.C[self._batchwise_row_indices[batch_index], 1:]

Expand All @@ -440,6 +484,15 @@ def decomposed_features(self):
\times N_F}` features per batch where :math:`N_i` is the number of non
zero elements in the :math:`i`th batch index in :math:`D` dimensional
space.
.. note::
The order of features is non-deterministic within each batch. Use
:attr:`decomposed_coordinates_and_features` to retrieve both
coordinates features with the same order. To retrieve the order the
decomposed features is generated, use
:attr:`decomposition_permutations`.
"""
return [self._F[row_inds] for row_inds in self._batchwise_row_indices]

Expand All @@ -450,6 +503,15 @@ def features_at(self, batch_index):
\times N_F}` feature matrix :math:`N` is the number of non
zero elements in the specified batch index and :math:`N_F` is the
number of channels.
.. note::
The order of features is non-deterministic within each batch. Use
:attr:`decomposed_coordinates_and_features` to retrieve both
coordinates features with the same order. To retrieve the order the
decomposed features is generated, use
:attr:`decomposition_permutations`.
"""
return self._F[self._batchwise_row_indices[batch_index]]

Expand All @@ -463,6 +525,13 @@ def coordinates_and_features_at(self, batch_index):
matrix is a torch.Tensor :math:`C \in \mathcal{R}^{N \times N_F}`
matrix :math:`N` is the number of non zero elements in the specified
batch index and :math:`N_F` is the number of channels.
.. note::
The order of features is non-deterministic within each batch. To
retrieve the order the decomposed features is generated, use
:attr:`decomposition_permutations`.
"""
row_inds = self._batchwise_row_indices[batch_index]
return self.C[row_inds, 1:], self._F[row_inds]
Expand All @@ -471,6 +540,13 @@ def coordinates_and_features_at(self, batch_index):
def decomposed_coordinates_and_features(self):
r"""Returns a list of coordinates and a list of features per batch.abs
.. note::
The order of decomposed coordinates and features is
non-deterministic within each batch. To retrieve the order the
decomposed features is generated, use
:attr:`decomposition_permutations`.
"""
row_inds_list = self._batchwise_row_indices
return (
Expand Down

0 comments on commit 1cdebb9

Please sign in to comment.