
Commit

sparse tensor with optional allow duplication, better quantization
chrischoy committed Dec 24, 2019
1 parent 64518c1 commit d5368a2
Showing 15 changed files with 385 additions and 112 deletions.
14 changes: 9 additions & 5 deletions CHANGELOG.md
@@ -1,7 +1,7 @@
# Change Log


## [nightly] - 2019-12-15
## [nightly] - 2019-12-24

- Synchronized Batch Norm: `ME.MinkowskiSyncBatchNorm`
- `ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm` automatically converts a MinkowskiNetwork to use synchronized batch norm.
@@ -10,20 +10,24 @@
- Update GIL release
- Minor error fixes on `examples/modelnet40.py`
- CoordsMap size initialization updates
- Added MinkowskiUnion
- Updated MinkowskiUnion, MinkowskiPruning docs
- Add MinkowskiUnion
- Update MinkowskiUnion, MinkowskiPruning docs
- Use cudaMalloc instead of `at::Tensor` for GPU memory management to fix illegal memory access and invalid argument errors.
- Region hypercube iterator with even-numbered kernel
- Fix global reduction in-out map with non-contiguous batch indices
- GlobalPooling with torch reduction
- GlobalPoolingMode with index select and sparse backbone
- If batch size == 1, skip the backend
- Added CoordsManager functions
- Add CoordsManager functions
  - `get_batch_size`
  - `get_batch_indices`
  - `set_origin_coords_key`
- Updated CoordsManager function `get_row_indices_per_batch` to return a list of `torch.LongTensor` for mapping indices. The corresponding batch indices are accessible via `get_batch_indices`.
- Update CoordsManager function `get_row_indices_per_batch` to return a list of `torch.LongTensor` for mapping indices. The corresponding batch indices are accessible via `get_batch_indices`.
- Update `MinkowskiBroadcast`, `MinkowskiBroadcastConcatenation` to use row indices per batch (`getRowIndicesPerBatch`)
- Update `SparseTensor`
  - `allow_duplicate_coords` argument support
  - Update documentation, add unit tests
- Add `quantize_th`, `quantize_label_th`


## [0.3.1] - 2019-12-15
6 changes: 3 additions & 3 deletions MinkowskiEngine/MinkowskiCoords.py
@@ -76,13 +76,13 @@ def initialize(self,
coords_key,
force_creation=False,
force_remap=False,
allow_duplicates=False):
allow_duplicate_coords=False):
assert isinstance(coords_key, CoordsKey)
mapping = torch.IntTensor()
mapping = torch.LongTensor()
self.CPPCoordsManager.initializeCoords(coords, mapping,
coords_key.CPPCoordsKey,
force_creation, force_remap,
allow_duplicates)
allow_duplicate_coords)
return mapping

def stride(self, coords_key, stride, force_creation=False):
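Since `initialize` now returns a `torch.LongTensor`, the mapping can index feature tensors directly. A minimal standalone sketch of how such a first-occurrence mapping deduplicates features (toy data; not the engine's internal code):

import numpy as np
import torch

coords = np.array([[0, 0, 0], [0, 1, 1], [0, 0, 0]], dtype=np.int32)
feats = torch.FloatTensor([[1.0], [2.0], [3.0]])

# np.unique(return_index=True) yields the row index of the first occurrence
# of each unique coordinate -- the role the returned LongTensor mapping plays.
_, index = np.unique(coords, axis=0, return_index=True)
mapping = torch.from_numpy(index).long()

unique_coords = torch.from_numpy(coords)[mapping]
unique_feats = feats[mapping]  # LongTensor indices are what torch indexing expects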
42 changes: 38 additions & 4 deletions MinkowskiEngine/SparseTensor.py
@@ -60,9 +60,14 @@ class SparseTensor():
.. warning::
From the version 0.3, we will put the batch indices on the first column
From the version 0.4, we will put the batch indices on the first column
to be consistent with the standard neural network packages.
Please use :attr:`MinkowskiEngine.utils.batched_coordinates` or
:attr:`MinkowskiEngine.utils.sparse_collate` to create coordinates so
that your code automatically generates batched coordinates compatible
with the latest version of the Minkowski Engine.
.. math::
\mathbf{C} = \begin{bmatrix}
@@ -83,6 +88,7 @@ def __init__(self,
coords_key=None,
coords_manager=None,
force_creation=False,
allow_duplicate_coords=False,
tensor_stride=1):
r"""
@@ -110,6 +116,19 @@ def __init__(self,
handled automatically and you do not need to use this. When you use
it, make sure you understand what you are doing.
:attr:`force_creation` (:attr:`bool`): Force creation of the
coordinates. This allows generating a new set of coordinates even
when there exists another set of coordinates with the same
tensor stride. This could happen when you manually feed the same
:attr:`coords_manager`.
:attr:`allow_duplicate_coords` (:attr:`bool`): Allow duplicate
coordinates when creating the sparse tensor. Internally, it will
generate a new unique set of coordinates and use the features at the
corresponding unique coordinates. In general, setting
`allow_duplicate_coords=True` is not recommended as it could hide
obvious errors in your data loading and preprocessing steps.
:attr:`tensor_stride` (:attr:`int`, :attr:`list`,
:attr:`numpy.array`, or :attr:`tensor.Tensor`): The tensor stride
of the current sparse tensor. By default, it is 1.
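A minimal usage sketch of the new argument, assuming the constructor signature shown in this diff with `feats` as the first positional argument (toy data; per the collation code in this commit, the batch index still occupies the last coordinate column until v0.4 moves it to the front, as the warning above notes):

import torch
import MinkowskiEngine as ME

# Two duplicate coordinate rows on purpose; the last column is the batch index
# in this version (it moves to the first column in v0.4).
coords = torch.IntTensor([[1, 1, 0], [1, 1, 0], [2, 3, 0]])
feats = torch.FloatTensor([[0.5], [0.7], [1.0]])

# Without allow_duplicate_coords=True, duplicate coordinates are an error;
# with it, a unique coordinate set is built and the features at the
# corresponding unique coordinates are used.
sinput = ME.SparseTensor(feats, coords=coords, allow_duplicate_coords=True)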
@@ -153,14 +172,29 @@ def __init__(self,
assert coords is not None, "Initial coordinates must be given"
D = coords.size(1) - 1
coords_manager = CoordsManager(D=D)
self.mapping = coords_manager.initialize(coords, coords_key)
self.mapping = coords_manager.initialize(
coords,
coords_key,
force_creation=force_creation,
force_remap=allow_duplicate_coords,
allow_duplicate_coords=allow_duplicate_coords)
if len(self.mapping) > 0:
coords = coords[self.mapping]
feats = feats[self.mapping]
else:
assert isinstance(coords_manager, CoordsManager)

if not coords_key.isKeySet():
assert coords is not None
self.mapping = coords_manager.initialize(
coords, coords_key, force_creation=force_creation)
coords,
coords_key,
force_creation=force_creation,
force_remap=allow_duplicate_coords,
allow_duplicate_coords=allow_duplicate_coords)
if len(self.mapping) > 0:
coords = coords[self.mapping]
feats = feats[self.mapping]

self._F = feats.contiguous()
self._C = coords
@@ -249,7 +283,7 @@ def __repr__(self):
return self.__class__.__name__ + '(' + os.linesep \
+ ' Coords=' + str(self.C) + os.linesep \
+ ' Feats=' + str(self.F) + os.linesep \
+ ' coords_key=' + str(self.coords_key) + os.linesep \
+ ' coords_key=' + str(self.coords_key) \
+ ' tensor_stride=' + str(self.coords_key.getTensorStride()) + os.linesep \
+ ' coords_man=' + str(self.coords_man) + ')'

50 changes: 30 additions & 20 deletions MinkowskiEngine/utils/collation.py
@@ -41,7 +41,7 @@ def batched_coordinates(coords):
.. warning::
From v0.3, the batch index will be prepended before all coordinates.
From v0.4, the batch index will be prepended before all coordinates.
"""
assert isinstance(
@@ -90,49 +90,59 @@ def sparse_collate(coords, feats, labels=None, is_double=False):
"""
use_label = False if labels is None else True
coords_batch, feats_batch, labels_batch = [], [], []
assert isinstance(coords, collections.abc.Sequence), \
"The coordinates must be a sequence of arrays or tensors."
assert isinstance(feats, collections.abc.Sequence), \
"The features must be a sequence of arrays or tensors."
if use_label:
assert isinstance(labels, collections.abc.Sequence), \
"The labels must be a sequence of arrays or tensors."

N = np.array([len(cs) for cs in coords]).sum()
Nf = np.array([len(fs) for fs in feats]).sum()
assert N == Nf, f"Coordinate length {N} != Feature length {Nf}"

batch_id = 0
s = 0 # start index
bcoords = torch.IntTensor(N, D + 1) # uninitialized batched coords
for coord, feat in zip(coords, feats):
if isinstance(coord, np.ndarray):
coord = torch.from_numpy(coord)
else:
assert isinstance(
coord, torch.Tensor
), "Coords must be of type numpy.ndarray or torch.Tensor"
assert isinstance(coord, torch.Tensor), \
"Coords must be of type numpy.ndarray or torch.Tensor"
coord = coord.int()

if isinstance(feat, np.ndarray):
feat = torch.from_numpy(feat)
else:
assert isinstance(
feat, torch.Tensor
), "Features must be of type numpy.ndarray or torch.Tensor"
assert isinstance(feat, torch.Tensor), \
"Features must be of type numpy.ndarray or torch.Tensor"
feat = feat.double() if is_double else feat.float()

# Batched coords
num_points = coord.shape[0]
coords_batch.append(
torch.cat((coord, torch.ones(num_points, 1).int() * batch_id), 1))

# Features
feats_batch.append(feat)

# Labels
if use_label:
label = labels[batch_id]
if isinstance(label, np.ndarray):
label = torch.from_numpy(label)
else:
assert isinstance(
label, torch.Tensor
), "labels must be of type numpy.ndarray or torch.Tensor"
assert isinstance(label, torch.Tensor), \
"labels must be of type numpy.ndarray or torch.Tensor"
labels_batch.append(label)

# Batched coords
cn = coord.shape[0]
bcoords[s:s + cn, :D] = coord
bcoords[s:s + cn, D] = batch_id

# Features
feats_batch.append(feat)

# Post processing steps
batch_id += 1
s += cn

# Concatenate all lists
coords_batch = torch.cat(coords_batch, 0).int()
feats_batch = torch.cat(feats_batch, 0)
if use_label:
labels_batch = torch.cat(labels_batch, 0)
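A usage sketch of `sparse_collate` on two toy clouds, assuming the two-value return when labels are omitted (illustrative shapes; the batch index is appended as the last coordinate column in this version, per the code above):

import numpy as np
from MinkowskiEngine.utils import sparse_collate

# Two "point clouds": 3 and 2 points with 2-D integer coordinates.
coords = [np.random.randint(0, 10, size=(3, 2)).astype(np.int32),
          np.random.randint(0, 10, size=(2, 2)).astype(np.int32)]
feats = [np.random.rand(3, 1).astype(np.float32),
         np.random.rand(2, 1).astype(np.float32)]

coords_batch, feats_batch = sparse_collate(coords, feats)
# coords_batch: 5 x 3 IntTensor (2 coordinate columns + batch index),
# feats_batch: 5 x 1 FloatTensor.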
78 changes: 58 additions & 20 deletions MinkowskiEngine/utils/quantization.py
@@ -62,12 +62,37 @@ def ravel_hash_vec(arr):
return keys


def quantize(coords):
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
"Invalid coords type"
if isinstance(coords, np.ndarray):
assert coords.dtype == np.int32, f"Invalid coords type {coords.dtype} != np.int32"
return MEB.quantize_np(coords.astype(np.int32))
else:
# Type check done inside
return MEB.quantize_th(coords.int())


def quantize_label(coords, labels, ignore_label):
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
"Invalid coords type"
if isinstance(coords, np.ndarray):
assert isinstance(labels, np.ndarray), "Invalid label type"
assert coords.dtype == np.int32, "Invalid coords type"
assert labels.dtype == np.int32, "Invalid label type"
return MEB.quantize_label_np(coords, labels, ignore_label)
else:
assert isinstance(labels, torch.Tensor), "Invalid label type"
# Type check done inside
return MEB.quantize_label_th(coords, labels, ignore_label)
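The wrappers dispatch on the input type, so callers get one interface for both NumPy arrays and torch tensors. A sketch with toy data (the import path is assumed from this file's location in the repo):

import numpy as np
import torch
from MinkowskiEngine.utils.quantization import quantize, quantize_label

coords_np = np.array([[0, 0], [0, 0], [1, 2]], dtype=np.int32)
labels_np = np.array([3, 4, 3], dtype=np.int32)

mapping_np = quantize(coords_np)                    # routed to MEB.quantize_np
mapping_th = quantize(torch.from_numpy(coords_np))  # routed to MEB.quantize_th

# Conflicting labels at the same coordinate collapse to the ignore label.
mapping, collapsed = quantize_label(coords_np, labels_np, 255)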


def sparse_quantize(coords,
feats=None,
labels=None,
ignore_label=255,
return_index=False,
quantization_size=1):
quantization_size=None):
r"""Given coordinates, and features (optionally labels), the function
generates quantized (voxelized) coordinates.
@@ -95,6 +120,9 @@
Please check `examples/indoor.py` for the usage.
"""
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
'Coords must be either np.array or torch.Tensor.'

use_label = labels is not None
use_feat = feats is not None

@@ -112,29 +140,36 @@
if use_label:
assert coords.shape[0] == len(labels)

# Quantize the coordinates
dimension = coords.shape[1]
if isinstance(quantization_size, (Sequence, np.ndarray, torch.Tensor)):
assert len(
quantization_size
) == dimension, "Quantization size and coordinates size mismatch."
quantization_size = np.array([i for i in quantization_size])
discrete_coords = np.floor(coords / quantization_size)
elif np.isscalar(quantization_size): # Assume that it is a scalar

if quantization_size == 1:
discrete_coords = coords
else:
quantization_size = np.array(
[quantization_size for i in range(dimension)])
# Quantize the coordinates
if quantization_size is not None:
if isinstance(quantization_size, (Sequence, np.ndarray, torch.Tensor)):
assert len(
quantization_size
) == dimension, "Quantization size and coordinates size mismatch."
quantization_size = np.array([i for i in quantization_size])
discrete_coords = np.floor(coords / quantization_size)
elif np.isscalar(quantization_size): # Assume that it is a scalar

if quantization_size == 1:
discrete_coords = coords
else:
discrete_coords = np.floor(coords / quantization_size)
else:
raise ValueError('Not supported type for quantization_size.')
else:
raise ValueError('Not supported type for quantization_size.')
discrete_coords = coords

discrete_coords = np.floor(discrete_coords)
if isinstance(coords, np.ndarray):
discrete_coords = discrete_coords.astype(np.int32)
else:
discrete_coords = discrete_coords.int()

# Return values accordingly
if use_label:
mapping, colabels = MEB.quantize_label(discrete_coords, labels,
ignore_label)
mapping, colabels = quantize_label(discrete_coords, labels,
ignore_label)

if return_index:
return mapping, colabels
@@ -145,7 +180,7 @@
return discrete_coords[mapping], colabels

else:
mapping = MEB.quantize(discrete_coords)
mapping = quantize(discrete_coords)
if len(mapping) > 0:
if return_index:
return mapping
@@ -157,7 +192,10 @@

else:
if return_index:
return np.arange(len(discrete_coords))
if isinstance(discrete_coords, np.ndarray):
return np.arange(len(discrete_coords))
else:
return torch.arange(len(discrete_coords), dtype=torch.long)
else:
if use_feat:
return discrete_coords, feats
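A usage sketch of the updated `sparse_quantize`, assuming the two-value return when only features are passed (toy float coordinates; with the new `quantization_size=None` default the coordinates are only floored and deduplicated, not rescaled):

import numpy as np
from MinkowskiEngine.utils import sparse_quantize

coords = (np.random.rand(100, 3) * 10).astype(np.float32)  # raw float coords
feats = np.random.rand(100, 4).astype(np.float32)

# Voxelize with a 0.5-unit voxel: floor(coords / 0.5), then drop duplicate
# voxels together with the features at the removed rows.
dcoords, dfeats = sparse_quantize(coords, feats=feats, quantization_size=0.5)

# Default (None): coordinates are floored to int32 and deduplicated as-is.
dcoords_raw, dfeats_raw = sparse_quantize(coords, feats=feats)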
12 changes: 12 additions & 0 deletions docs/utils.rst
@@ -7,6 +7,18 @@ sparse_quantize
.. autofunction:: MinkowskiEngine.utils.sparse_quantize


batched_coordinates
-------------------

.. autofunction:: MinkowskiEngine.utils.batched_coordinates


sparse_collate
--------------

.. autofunction:: MinkowskiEngine.utils.sparse_collate


SparseCollation
---------------

9 changes: 7 additions & 2 deletions pybind/extern.hpp
@@ -379,9 +379,14 @@ UnionBackwardGPU(at::Tensor grad_out_feat, vector<py::object> py_in_coords_keys,
* Quantization
*************************************/
vector<int>
quantize(py::array_t<int, py::array::c_style | py::array::forcecast> coords);
quantize_np(py::array_t<int, py::array::c_style | py::array::forcecast> coords);

vector<py::array> quantize_label(
vector<py::array> quantize_label_np(
py::array_t<int, py::array::c_style | py::array::forcecast> coords,
py::array_t<int, py::array::c_style | py::array::forcecast> labels,
int invalid_label);

at::Tensor quantize_th(at::Tensor coords);

vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
int invalid_label);
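From Python, the new `_th` overloads can be exercised directly through the compiled backend, skipping the NumPy round trip. A hedged sketch (the `MinkowskiEngineBackend` module name and the exact return layout are assumptions based on how `MEB` is used in the Python files above):

import torch
import MinkowskiEngineBackend as MEB  # the compiled extension imported as MEB

coords = torch.IntTensor([[0, 0], [0, 0], [1, 2]])
labels = torch.IntTensor([3, 4, 3])

mapping = MEB.quantize_th(coords)  # at::Tensor in, index tensor out
mapping2, colabels = MEB.quantize_label_th(coords, labels, 255)
# Rows sharing a coordinate but disagreeing on the label collapse to the
# invalid_label argument (255 here), mirroring quantize_label_np.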
