From 5778fc56ad6bc30e71fff7b156990bf9afd4ed7b Mon Sep 17 00:00:00 2001 From: chrischoy Date: Wed, 1 Sep 2021 13:31:46 -0700 Subject: [PATCH] handle an empty tensor in dense func (#384) --- CHANGELOG.md | 1 + MinkowskiEngine/MinkowskiSparseTensor.py | 25 ++++++++++++++++++------ tests/python/dense.py | 10 ++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eeffc2cd..f2f81a4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - Fix for GPU `coo_spmm` when nnz == 0 - Fix MinkowskiInterpolationGPU for invalid samples (issue #383) - gradcheck wrap func 1.9 +- Handle `dense` for an empty tensor (issue #384) ## [0.5.4] diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py index 3c556351..b80d5f49 100644 --- a/MinkowskiEngine/MinkowskiSparseTensor.py +++ b/MinkowskiEngine/MinkowskiSparseTensor.py @@ -199,8 +199,12 @@ def __init__( assert isinstance(coordinates, torch.Tensor) if coordinate_map_key is not None: assert isinstance(coordinate_map_key, CoordinateMapKey) - assert coordinate_manager is not None, "Must provide coordinate_manager if coordinate_map_key is provided" - assert coordinates is None, "Must not provide coordinates if coordinate_map_key is provided" + assert ( + coordinate_manager is not None + ), "Must provide coordinate_manager if coordinate_map_key is provided" + assert ( + coordinates is None + ), "Must not provide coordinates if coordinate_map_key is provided" if coordinate_manager is not None: assert isinstance(coordinate_manager, CoordinateManager) if coordinates is None and ( @@ -404,7 +408,6 @@ def torch_sparse_Tensor(coords, feats, size=None): coords = self.C coords, batch_indices = coords[:, 1:], coords[:, 0] - # TODO, batch first if min_coords is None: min_coords, _ = coords.min(0, keepdim=True) elif min_coords.ndim == 1: @@ -491,13 +494,21 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True): if shape[1] != self._F.size(1): shape = torch.Size([shape[0], self._F.size(1), *[s for s in shape[2:]]]) + # Exception handling for empty tensor + if self.__len__() == 0: + assert shape is not None, "shape is required to densify an empty tensor" + return ( + torch.zeros(shape, dtype=self.dtype, device=self.device), + torch.zeros(self._D, dtype=torch.int32, device=self.device), + self.tensor_stride, + ) + # Use int tensor for all operations tensor_stride = torch.IntTensor(self.tensor_stride).to(self.device) # New coordinates batch_indices = self.C[:, 0] - # TODO, batch first if min_coordinate is None: min_coordinate, _ = self.C.min(0, keepdim=True) min_coordinate = min_coordinate[:, 1:] @@ -528,9 +539,11 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True): nchannels = self.F.size(1) if shape is None: size = coords.max(0)[0] + 1 - shape = torch.Size([batch_indices.max() + 1, nchannels, *size.cpu().numpy()]) + shape = torch.Size( + [batch_indices.max() + 1, nchannels, *size.cpu().numpy()] + ) - dense_F = torch.zeros(shape, dtype=self.F.dtype, device=self.F.device) + dense_F = torch.zeros(shape, dtype=self.dtype, device=self.device) tcoords = coords.t().long() batch_indices = batch_indices.long() diff --git a/tests/python/dense.py b/tests/python/dense.py index 6fcc0f94..2e64fbd2 100644 --- a/tests/python/dense.py +++ b/tests/python/dense.py @@ -94,6 +94,16 @@ def test(self): print(feats.grad) + def test_empty(self): + x = torch.zeros(4, 1, 34, 34) + to_dense = ME.MinkowskiToDenseTensor(x.shape) + + # Convert to sparse data + sparse_data = ME.to_sparse(x) + dense_data = to_dense(sparse_data) + + self.assertEqual(dense_data.shape, x.shape) + class TestDenseToSparse(unittest.TestCase): def test(self):