Skip to content

Commit

Permalink
Templated coords map & sparse tensor quantization mode
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit da5a1bbc6177813466b92c813926d077f5f1398f
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Fri Apr 10 01:15:12 2020 -0700

    sparse tensor with automatic feature average for duplicate coords

commit 8534888
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Wed Apr 8 01:55:47 2020 -0700

    coords map instantiations and fixes

commit 28fbde9
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sat Apr 4 01:26:36 2020 -0700

    Coords MapType

commit 3df3a03
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Thu Apr 2 18:43:45 2020 -0700

    restructure coordsmap init

commit 496d8f8
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Thu Apr 2 14:50:28 2020 -0700

    quantization
  • Loading branch information
chrischoy committed Apr 10, 2020
1 parent 9d8dfae commit fb95f93
Show file tree
Hide file tree
Showing 33 changed files with 2,066 additions and 941 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Change Log

## [master] - 2020-04-01
## [master] - 2020-04-08

### Changed

Expand All @@ -11,6 +11,10 @@
- Add `coordinates_and_features_at(batch_index)` function in the SparseTensor class.
- Add `MinkowskiChannelwiseConvolution` (Issue #92)
- Update `MinkowskiPruning` to generate an empty sparse tensor as output (Issue #102)
- Add `return_index` for `sparse_quantize`
- Templated CoordsManager for coords to int and coords to vector classes
- Sparse tensor quantization mode
- Features at duplicated coordinates will be averaged automatically with `quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE`


## [0.4.2] - 2020-03-13
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
-DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=$(EXTENSION_NAME) \
-D_GLIBCXX_USE_CXX11_ABI=$(WITH_ABI)

CXXFLAGS += -fopenmp -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
CXXFLAGS += -fopenmp -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
LINKFLAGS += -pthread -fPIC $(WARNINGS) -Wl,-rpath=$(PYTHON_LIB_DIR) -Wl,--no-as-needed -Wl,--sysroot=/
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
$(foreach library,$(LIBRARIES),-l$(library))
Expand Down
14 changes: 8 additions & 6 deletions MinkowskiEngine/MinkowskiCoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ def initialize(self,
coords_key: CoordsKey,
force_creation: bool = False,
force_remap: bool = False,
allow_duplicate_coords: bool = False) -> torch.LongTensor:
allow_duplicate_coords: bool = False,
return_inverse: bool = False) -> torch.LongTensor:
assert isinstance(coords_key, CoordsKey)
mapping = torch.LongTensor()
self.CPPCoordsManager.initializeCoords(coords, mapping,
unique_index = torch.LongTensor()
inverse_mapping = torch.LongTensor()
self.CPPCoordsManager.initializeCoords(coords, unique_index, inverse_mapping,
coords_key.CPPCoordsKey,
force_creation, force_remap,
allow_duplicate_coords)
return mapping
allow_duplicate_coords, return_inverse)
return unique_index, inverse_mapping

def create_coords_key(self,
coords: torch.IntTensor,
Expand All @@ -100,7 +102,7 @@ def create_coords_key(self,
allow_duplicate_coords: bool = False) -> CoordsKey:
coords_key = CoordsKey(self.D)
coords_key.setTensorStride(tensor_stride)
mapping = self.initialize(
unique_index, inverse_mapping = self.initialize(
coords,
coords_key,
force_creation=True,
Expand Down
109 changes: 76 additions & 33 deletions MinkowskiEngine/SparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from Common import convert_to_int_list
from MinkowskiCoords import CoordsKey, CoordsManager
import MinkowskiEngineBackend as MEB


class SparseTensorOperationMode(Enum):
Expand All @@ -43,6 +44,15 @@ class SparseTensorOperationMode(Enum):
SHARE_COORDS_MANAGER = 1


class SparseTensorQuantizationMode(Enum):
"""
RANDOM_SUBSAMPLE: Subsample one coordinate per each quantization block randomly.
UNWEIGHTED_AVERAGE: average all features within a quantization block equally.
"""
RANDOM_SUBSAMPLE = 0
UNWEIGHTED_AVERAGE = 1


_sparse_tensor_operation_mode = SparseTensorOperationMode.SEPARATE_COORDS_MANAGER
_global_coords_man = None
COORDS_MAN_DIFFERENT_ERROR = "SparseTensors must share the same coordinate manager for this operation. Please refer to the SparseTensor creation API (https://stanfordvl.github.io/MinkowskiEngine/sparse_tensor.html) to share the coordinate manager, or set the sparse tensor operation mode with `set_sparse_tensor_operation_mode` to share it by default."
Expand All @@ -51,7 +61,7 @@ class SparseTensorOperationMode(Enum):

def set_sparse_tensor_operation_mode(operation_mode: SparseTensorOperationMode):
assert isinstance(operation_mode, SparseTensorOperationMode), \
f"Input must be an instance of SparseTensorOperationMode not {operation_mode}"
f"Input must be an instance of SparseTensorOperationMode not {operation_mode}"
global _sparse_tensor_operation_mode
_sparse_tensor_operation_mode = operation_mode

Expand Down Expand Up @@ -127,14 +137,16 @@ class SparseTensor():
"""

def __init__(self,
feats,
coords=None,
coords_key=None,
coords_manager=None,
force_creation=False,
allow_duplicate_coords=False,
tensor_stride=1):
def __init__(
self,
feats,
coords=None,
coords_key=None,
coords_manager=None,
force_creation=False,
allow_duplicate_coords=False,
quantization_mode=SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
tensor_stride=1):
r"""
Args:
Expand All @@ -150,8 +162,8 @@ def __init__(self,
:attr:`coords_key` (:attr:`MinkowskiEngine.CoordsKey`): When the
coordinates are already cached in the MinkowskiEngine, we could
reuse the same coordinates by simply providing the coordinate hash
key. In most case, this process is done automatically. If you
provide one, make sure you understand what you are doing.
key. In most case, this process is done automatically. When you
provide a `coords_key`, all other arguments will be be ignored.
:attr:`coords_manager` (:attr:`MinkowskiEngine.CoordsManager`): The
MinkowskiEngine creates a dynamic computation graph and all
Expand All @@ -177,6 +189,11 @@ def __init__(self,
<https://stanfordvl.github.io/MinkowskiEngine/demo/training.html>`_
for more details.
:attr:`quantizatino_mode`
(:attr:`MinkowskiEngine.SparseTensorQuantizationMode`): Defines the
quantization method and how to define features of a sparse tensor.
Please refer to :attr:`SparseTensorQuantizationMode` for details.
:attr:`tensor_stride` (:attr:`int`, :attr:`list`,
:attr:`numpy.array`, or :attr:`tensor.Tensor`): The tensor stride
of the current sparse tensor. By default, it is 1.
Expand Down Expand Up @@ -219,10 +236,13 @@ def __init__(self,
coords = coords.cpu()

assert feats.shape[0] == coords.shape[0], \
"Number of rows in features and coordinates do not match."
"The number of rows in features and coordinates do not match."

coords = coords.contiguous()

##########################
# Setup CoordsManager
##########################
if coords_manager is None:
# If set to share the coords man, use the global coords man
global _sparse_tensor_operation_mode, _global_coords_man
Expand All @@ -234,30 +254,53 @@ def __init__(self,
assert coords is not None, "Initial coordinates must be given"
coords_manager = CoordsManager(D=coords.size(1) - 1)

if not coords_key.isKeySet():
self.mapping = coords_manager.initialize(
coords,
coords_key,
force_creation=force_creation,
force_remap=allow_duplicate_coords,
allow_duplicate_coords=allow_duplicate_coords)
if len(self.mapping) > 0:
coords = coords[self.mapping]
feats = feats[self.mapping]
else:
assert isinstance(coords_manager, CoordsManager)

if not coords_key.isKeySet():
assert coords is not None
self.mapping = coords_manager.initialize(
coords,
coords_key,
force_creation=force_creation,
force_remap=allow_duplicate_coords,
allow_duplicate_coords=allow_duplicate_coords)
if len(self.mapping) > 0:
coords = coords[self.mapping]
feats = feats[self.mapping]
##########################
# Initialize coords
##########################
if not coords_key.isKeySet() and coords is not None and len(coords) > 0:
assert isinstance(quantization_mode, SparseTensorQuantizationMode)

if quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE:
force_remap = True
return_inverse = False
elif quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE:
force_remap = True
return_inverse = True

self.unique_index, self.inverse_mapping = coords_manager.initialize(
coords,
coords_key,
force_creation=force_creation,
force_remap=force_remap,
allow_duplicate_coords=allow_duplicate_coords,
return_inverse=return_inverse)

if quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE:
feats = MEB.quantization_average_features(
feats, torch.arange(len(feats)), self.inverse_mapping,
len(self.unique_index), 0)
elif force_remap:
assert len(self.unique_index) > 0
self._CC = coords
self._CF = feats
coords = coords[self.unique_index]
feats = feats[self.unique_index]

elif coords is not None: # empty / invalid coords
assert isinstance(coords, torch.IntTensor)
assert coords.ndim == 2
coords_manager.initialize(
coords,
coords_key,
force_creation=force_creation,
force_remap=False,
allow_duplicate_coords=False,
return_inverse=False)
elif coords_key is not None:
assert coords_key.isKeySet()

self._F = feats.contiguous()
self._C = coords
Expand Down
2 changes: 1 addition & 1 deletion MinkowskiEngine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# Must be imported first to load all required shared libs
import torch

from SparseTensor import SparseTensor, SparseTensorOperationMode, \
from SparseTensor import SparseTensor, SparseTensorOperationMode, SparseTensorQuantizationMode, \
set_sparse_tensor_operation_mode, sparse_tensor_operation_mode, clear_global_coords_man

from Common import RegionType, convert_to_int_tensor, convert_region_type, \
Expand Down
19 changes: 11 additions & 8 deletions examples/indoor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--file_name', type=str, default='1.ply')
parser.add_argument('--weights', type=str, default='weights.pth')
parser.add_argument('--use_cpu', action='store_true')

CLASS_LABELS = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
Expand Down Expand Up @@ -102,11 +103,8 @@ def load_file(file_name, voxel_size):
pcd = o3d.io.read_point_cloud(file_name)
coords = np.array(pcd.points)
feats = np.array(pcd.colors)

quantized_coords = np.floor(coords / voxel_size)
inds = ME.utils.sparse_quantize(quantized_coords, return_index=True)

return quantized_coords[inds], feats[inds], pcd
return quantized_coords, feats, pcd


def generate_input_sparse_tensor(file_name, voxel_size=0.05):
Expand All @@ -121,8 +119,9 @@ def generate_input_sparse_tensor(file_name, voxel_size=0.05):

if __name__ == '__main__':
config = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device = torch.device('cuda' if (
torch.cuda.is_available() and not config.use_cpu) else 'cpu')
print(f"Using {device}")
# Define a model and load the weights
model = MinkUNet34C(3, 20).to(device)
model_dict = torch.load(config.weights)
Expand All @@ -137,7 +136,10 @@ def generate_input_sparse_tensor(file_name, voxel_size=0.05):

# Feed-forward pass and get the prediction
sinput = ME.SparseTensor(
features, coords=coordinates).to(device)
features,
coords=coordinates,
quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
).to(device)
soutput = model(sinput)

# Feed-forward pass and get the prediction
Expand All @@ -151,7 +153,8 @@ def generate_input_sparse_tensor(file_name, voxel_size=0.05):
pred_pcd = o3d.geometry.PointCloud()
# Map color
colors = np.array([SCANNET_COLOR_MAP[VALID_CLASS_IDS[l]] for l in pred])
pred_pcd.points = o3d.utility.Vector3dVector(coordinates[batch_index] * 0.02)
pred_pcd.points = o3d.utility.Vector3dVector(coordinates[batch_index] *
0.02)
pred_pcd.colors = o3d.utility.Vector3dVector(colors / 255)

# Move the original point cloud
Expand Down
Loading

0 comments on commit fb95f93

Please sign in to comment.