Skip to content

Commit

Permalink
conv kernel size 1 with explicit coordinates (fix #203)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent b7d4c34 commit 75e9e37
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 42 deletions.
4 changes: 2 additions & 2 deletions MinkowskiEngine/MinkowskiConvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,11 @@ def forward(
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

if self.use_mm and coordinates is None:
if self.use_mm:
# If the kernel_size == 1, the convolution is simply a matrix
# multiplication
out_coordinate_map_key = _get_coordinate_map_key(input, coordinates)
outfeat = input.F.mm(self.kernel)
out_coordinate_map_key = input.coordinate_map_key
else:
# Get a new coordinate_map_key or extract one from the coords
out_coordinate_map_key = _get_coordinate_map_key(input, coordinates)
Expand Down
12 changes: 9 additions & 3 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import torch
import warnings

from MinkowskiCommon import StrideType
from MinkowskiCommon import convert_to_int_list, StrideType
from MinkowskiEngineBackend._C import CoordinateMapKey
from MinkowskiTensor import (
SparseTensorQuantizationMode,
Expand Down Expand Up @@ -493,9 +493,15 @@ def _get_coordinate_map_key(
if coordinates is not None:
assert isinstance(coordinates, (CoordinateMapKey, torch.Tensor, SparseTensor))
if isinstance(coordinates, torch.Tensor):
coordinate_map_key = input._manager.create_coordinate_map_key(
coordinates, tensor_stride=tensor_stride
assert coordinates.ndim == 2
coordinate_map_key = CoordinateMapKey(
convert_to_int_list(tensor_stride, coordinates.size(1) - 1), ""
)

(
coordinate_map_key,
(unique_index, inverse_mapping),
) = input._manager.insert_and_map(coordinates, *coordinate_map_key.get_key())
elif isinstance(coordinates, SparseTensor):
coordinate_map_key = coordinates.coordinate_map_key
else: # CoordinateMapKey type due to the previous assertion
Expand Down
14 changes: 7 additions & 7 deletions MinkowskiEngine/MinkowskiTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def get_device(self):
return self._F.get_device()

def _is_same_key(self, other):
assert isinstance(other, SparseTensor)
assert isinstance(other, self.__class__)
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
assert (
self.coordinate_map_key == other.coordinate_map_key
Expand Down Expand Up @@ -587,12 +587,12 @@ def __idiv__(self, other):
return self

def _binary_functor(self, other, binary_fn):
assert isinstance(other, (SparseTensor, torch.Tensor))
if isinstance(other, SparseTensor):
assert isinstance(other, (self.__class__, torch.Tensor))
if isinstance(other, self.__class__):
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR

if self.coordinate_map_key == other.coordinate_map_key:
return SparseTensor(
return self.__class__(
binary_fn(self._F, other.F),
coordinate_map_key=self.coordinate_map_key,
coordinate_manager=self._manager,
Expand All @@ -609,11 +609,11 @@ def _binary_functor(self, other, binary_fn):
)
out_F[outs[0]] = self._F[ins[0]]
out_F[outs[1]] = binary_fn(out_F[outs[1]], other._F[ins[1]])
return SparseTensor(
return self.__class__(
out_F, coordinate_map_key=out_key, coords_manager=self._manager
)
else: # when it is a torch.Tensor
return SparseTensor(
return self.__class__(
binary_fn(self._F, other),
coordinate_map_key=self.coordinate_map_key,
coordinate_manager=self._manager,
Expand Down Expand Up @@ -659,7 +659,7 @@ def __truediv__(self, other):
return self._binary_functor(other, lambda x, y: x / y)

def __power__(self, power):
return SparseTensor(
return self.__class__(
self._F ** power,
coordinate_map_key=self.coordinate_map_key,
coordinate_manager=self._manager,
Expand Down
73 changes: 43 additions & 30 deletions MinkowskiEngine/utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def fnv_hash_vec(arr):
# Floor first for negative coordinates
arr = arr.copy()
arr = arr.astype(np.uint64, copy=False)
hashed_arr = np.uint64(14695981039346656037) * \
np.ones(arr.shape[0], dtype=np.uint64)
hashed_arr = np.uint64(14695981039346656037) * np.ones(
arr.shape[0], dtype=np.uint64
)
for j in range(arr.shape[1]):
hashed_arr *= np.uint64(1099511628211)
hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
Expand Down Expand Up @@ -87,37 +88,47 @@ def quantize(coords):
>>> print(coords[unique_map[inverse_map]] == coords) # True, ..., True
"""
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
"Invalid coords type"
assert isinstance(coords, np.ndarray) or isinstance(
coords, torch.Tensor
), "Invalid coords type"
if isinstance(coords, np.ndarray):
assert coords.dtype == np.int32, f"Invalid coords type {coords.dtype} != np.int32"
assert (
coords.dtype == np.int32
), f"Invalid coords type {coords.dtype} != np.int32"
return MEB.quantize_np(coords.astype(np.int32))
else:
# Type check done inside
return MEB.quantize_th(coords.int())


def quantize_label(coords, labels, ignore_label):
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
"Invalid coords type"
assert isinstance(coords, np.ndarray) or isinstance(
coords, torch.Tensor
), "Invalid coords type"
if isinstance(coords, np.ndarray):
assert isinstance(labels, np.ndarray)
assert coords.dtype == np.int32, f"Invalid coords type {coords.dtype} != np.int32"
assert labels.dtype == np.int32, f"Invalid label type {labels.dtype} != np.int32"
assert (
coords.dtype == np.int32
), f"Invalid coords type {coords.dtype} != np.int32"
assert (
labels.dtype == np.int32
), f"Invalid label type {labels.dtype} != np.int32"
return MEB.quantize_label_np(coords, labels, ignore_label)
else:
assert isinstance(labels, torch.Tensor)
# Type check done inside
return MEB.quantize_label_th(coords, labels.int(), ignore_label)


def sparse_quantize(coords,
feats=None,
labels=None,
ignore_label=-100,
return_index=False,
return_inverse=False,
quantization_size=None):
def sparse_quantize(
coords,
feats=None,
labels=None,
ignore_label=-100,
return_index=False,
return_inverse=False,
quantization_size=None,
):
r"""Given coordinates, and features (optionally labels), the function
generates quantized (voxelized) coordinates.
Expand Down Expand Up @@ -176,14 +187,17 @@ def sparse_quantize(coords,
"""
assert isinstance(coords, np.ndarray) or isinstance(coords, torch.Tensor), \
'Coords must be either np.array or torch.Tensor.'
assert isinstance(coords, np.ndarray) or isinstance(
coords, torch.Tensor
), "Coords must be either np.array or torch.Tensor."

use_label = labels is not None
use_feat = feats is not None

assert coords.ndim == 2, \
"The coordinates must be a 2D matrix. The shape of the input is " + str(coords.shape)
assert coords.ndim == 2, (
"The coordinates must be a 2D matrix. The shape of the input is "
+ str(coords.shape)
)

if return_inverse:
assert return_index, "return_reverse must be set with return_index"
Expand All @@ -199,9 +213,9 @@ def sparse_quantize(coords,
# Quantize the coordinates
if quantization_size is not None:
if isinstance(quantization_size, (Sequence, np.ndarray, torch.Tensor)):
assert len(
quantization_size
) == dimension, "Quantization size and coordinates size mismatch."
assert (
len(quantization_size) == dimension
), "Quantization size and coordinates size mismatch."
if isinstance(coords, np.ndarray):
quantization_size = np.array([i for i in quantization_size])
discrete_coords = np.floor(coords / quantization_size)
Expand All @@ -216,7 +230,7 @@ def sparse_quantize(coords,
else:
discrete_coords = np.floor(coords / quantization_size)
else:
raise ValueError('Not supported type for quantization_size.')
raise ValueError("Not supported type for quantization_size.")
else:
discrete_coords = coords

Expand All @@ -228,24 +242,23 @@ def sparse_quantize(coords,

# Return values accordingly
if use_label:
mapping, colabels = quantize_label(discrete_coords, labels,
ignore_label)
unique_map, colabels = quantize_label(discrete_coords, labels, ignore_label)

if return_index:
return discrete_coords[unique_map], mapping, colabels
return discrete_coords[unique_map], unique_map, colabels
else:
if use_feat:
return discrete_coords[mapping], feats[mapping], colabels
return discrete_coords[unique_map], feats[unique_map], colabels
else:
return discrete_coords[mapping], colabels
return discrete_coords[unique_map], colabels

else:
unique_map, inverse_map = quantize(discrete_coords)
if return_index:
if return_inverse:
return discrete_coords[unique_map], unique_map, inverse_map
else:
return discrete_coords[unique_map], uunique_map
return discrete_coords[unique_map], unique_map
else:
if use_feat:
return discrete_coords[unique_map], feats[unique_map]
Expand Down

0 comments on commit 75e9e37

Please sign in to comment.