Skip to content

Commit

Permalink
tfield decomposition fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 31, 2020
1 parent 1b10a8e commit 032bb51
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions MinkowskiEngine/MinkowskiTensorField.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key):
spmm = MinkowskiSPMMFunction()
N = len(features)
cols = torch.arange(
N, dtype=self.inverse_mapping.dtype, device=self.inverse_mapping.device,
N,
dtype=self.inverse_mapping.dtype,
device=self.inverse_mapping.device,
)
vals = torch.ones(N, dtype=features.dtype, device=features.device)
size = torch.Size([len(self.unique_index), len(self.inverse_mapping)])
Expand All @@ -126,7 +128,11 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key):
== SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
):
nums = spmm.apply(
self.inverse_mapping, cols, vals, size, vals.reshape(N, 1),
self.inverse_mapping,
cols,
vals,
size,
vals.reshape(N, 1),
)
features /= nums
elif self.quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE:
Expand All @@ -139,8 +145,7 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key):

@property
def C(self):
r"""The alias of :attr:`coords`.
"""
r"""The alias of :attr:`coords`."""
return self.coordinates

@property
Expand All @@ -154,10 +159,17 @@ def coordinates(self):
internally treated as an additional spatial dimension to disassociate
different instances in a batch.
"""
if not hasattr(self, '_CC') or self._CC is None:
if not hasattr(self, "_CC") or self._CC is None:
self._CC = self._get_coordinate_field()
return self._CC

@property
def _batchwise_row_indices(self):
if self._batch_rows is None:
batch_inds = torch.unique(self._CC[:, 0])
self._batch_rows = [self._CC[:, 0] == b for b in batch_inds]
return self._batch_rows

def _get_coordinate_field(self):
return self._manager.get_coordinate_field(self.coordinate_field_map_key)

Expand All @@ -167,7 +179,9 @@ def sparse(self):
N = len(self._F)
assert N == len(self.inverse_mapping), "invalid inverse mapping"
cols = torch.arange(
N, dtype=self.inverse_mapping.dtype, device=self.inverse_mapping.device,
N,
dtype=self.inverse_mapping.dtype,
device=self.inverse_mapping.device,
)
vals = torch.ones(N, dtype=self._F.dtype, device=self._F.device)
size = torch.Size(
Expand All @@ -176,7 +190,13 @@ def sparse(self):
features = spmm.apply(self.inverse_mapping, cols, vals, size, self._F)
# int_inverse_mapping = self.inverse_mapping.int()
if self.quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE:
nums = spmm.apply(self.inverse_mapping, cols, vals, size, vals.reshape(N, 1),)
nums = spmm.apply(
self.inverse_mapping,
cols,
vals,
size,
vals.reshape(N, 1),
)
features /= nums

return SparseTensor(
Expand Down

0 comments on commit 032bb51

Please sign in to comment.