Skip to content

Commit

Permalink
handle an empty tensor in dense func (NVIDIA#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Sep 1, 2021
1 parent 0ec0fb4 commit 5778fc5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- Fix for GPU `coo_spmm` when nnz == 0
- Fix MinkowskiInterpolationGPU for invalid samples (issue #383)
- gradcheck wrap func 1.9
- Handle `dense` for an empty tensor (issue #384)

## [0.5.4]

Expand Down
25 changes: 19 additions & 6 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,12 @@ def __init__(
assert isinstance(coordinates, torch.Tensor)
if coordinate_map_key is not None:
assert isinstance(coordinate_map_key, CoordinateMapKey)
assert coordinate_manager is not None, "Must provide coordinate_manager if coordinate_map_key is provided"
assert coordinates is None, "Must not provide coordinates if coordinate_map_key is provided"
assert (
coordinate_manager is not None
), "Must provide coordinate_manager if coordinate_map_key is provided"
assert (
coordinates is None
), "Must not provide coordinates if coordinate_map_key is provided"
if coordinate_manager is not None:
assert isinstance(coordinate_manager, CoordinateManager)
if coordinates is None and (
Expand Down Expand Up @@ -404,7 +408,6 @@ def torch_sparse_Tensor(coords, feats, size=None):
coords = self.C
coords, batch_indices = coords[:, 1:], coords[:, 0]

# TODO, batch first
if min_coords is None:
min_coords, _ = coords.min(0, keepdim=True)
elif min_coords.ndim == 1:
Expand Down Expand Up @@ -491,13 +494,21 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True):
if shape[1] != self._F.size(1):
shape = torch.Size([shape[0], self._F.size(1), *[s for s in shape[2:]]])

# Exception handling for empty tensor
if self.__len__() == 0:
assert shape is not None, "shape is required to densify an empty tensor"
return (
torch.zeros(shape, dtype=self.dtype, device=self.device),
torch.zeros(self._D, dtype=torch.int32, device=self.device),
self.tensor_stride,
)

# Use int tensor for all operations
tensor_stride = torch.IntTensor(self.tensor_stride).to(self.device)

# New coordinates
batch_indices = self.C[:, 0]

# TODO, batch first
if min_coordinate is None:
min_coordinate, _ = self.C.min(0, keepdim=True)
min_coordinate = min_coordinate[:, 1:]
Expand Down Expand Up @@ -528,9 +539,11 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True):
nchannels = self.F.size(1)
if shape is None:
size = coords.max(0)[0] + 1
shape = torch.Size([batch_indices.max() + 1, nchannels, *size.cpu().numpy()])
shape = torch.Size(
[batch_indices.max() + 1, nchannels, *size.cpu().numpy()]
)

dense_F = torch.zeros(shape, dtype=self.F.dtype, device=self.F.device)
dense_F = torch.zeros(shape, dtype=self.dtype, device=self.device)

tcoords = coords.t().long()
batch_indices = batch_indices.long()
Expand Down
10 changes: 10 additions & 0 deletions tests/python/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def test(self):

print(feats.grad)

def test_empty(self):
x = torch.zeros(4, 1, 34, 34)
to_dense = ME.MinkowskiToDenseTensor(x.shape)

# Convert to sparse data
sparse_data = ME.to_sparse(x)
dense_data = to_dense(sparse_data)

self.assertEqual(dense_data.shape, x.shape)


class TestDenseToSparse(unittest.TestCase):
def test(self):
Expand Down

0 comments on commit 5778fc5

Please sign in to comment.