Skip to content

Commit

Permalink
sparse_quantize with labels supports inverse map (fix #271)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 30, 2020
1 parent b6ac373 commit 84c52b7
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 65 deletions.
12 changes: 10 additions & 2 deletions MinkowskiEngine/MinkowskiPooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,13 @@ class MinkowskiPoolingTranspose(MinkowskiPoolingBase):
"""

def __init__(
self, kernel_size, stride, dilation=1, kernel_generator=None, expand_coordinates=False, dimension=None
self,
kernel_size,
stride,
dilation=1,
kernel_generator=None,
expand_coordinates=False,
dimension=None,
):
r"""a high-dimensional unpooling layer for sparse tensors.
Expand Down Expand Up @@ -624,7 +630,9 @@ def backward(ctx, grad_out_feat):
class MinkowskiGlobalPooling(MinkowskiModuleBase):
r"""Pool all input features to one output."""

def __init__(self, mode: PoolingMode = PoolingMode.GLOBAL_AVG_POOLING_DEFAULT):
def __init__(
self, mode: PoolingMode = PoolingMode.GLOBAL_AVG_POOLING_PYTORCH_INDEX
):
r"""Reduces sparse coords into points at origin, i.e. reduce each point
cloud into a point at the origin, returning batch_size number of points
[[0, 0, ..., 0], [0, 0, ..., 1],, [0, 0, ..., 2]] where the last elem
Expand Down
63 changes: 29 additions & 34 deletions MinkowskiEngine/utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def sparse_quantize(
elif "cuda" in device:
manager = MEB.CoordinateMapManagerGPU_c10()
else:
raise ValueError("Invalid device")
raise ValueError("Invalid device. Only `cpu` or `cuda` supported.")

# Return values accordingly
if use_label:
Expand All @@ -277,32 +277,29 @@ def sparse_quantize(
discrete_coordinates, tensor_stride, ""
)

assert (
device == "cpu"
), "CUDA accelerated quantization with labels not supported currently"
# assert (
# device == "cpu"
# ), "CUDA accelerated quantization with labels not supported currently"

if return_maps_only:
return unique_map
if return_inverse:
return unique_map, inverse_map
else:
return unique_map

return_args = [discrete_coordinates[unique_map]]
if use_feat:
return_args.append(features[unique_map])
return_args.append(labels[unique_map])
if return_index:
if use_feat:
return (
discrete_coordinates[unique_map],
features[unique_map],
labels[unique_map],
unique_map,
)
else:
return discrete_coordinates[unique_map], labels[unique_map], unique_map
return_args.append(unique_map)
if return_inverse:
return_args.append(inverse_map)

if len(return_args) == 1:
return return_args[0]
else:
if use_feat:
return (
discrete_coordinates[unique_map],
features[unique_map],
labels[unique_map],
)
else:
return discrete_coordinates[unique_map], labels[unique_map]
return tuple(return_args)
else:
tensor_stride = [1 for i in range(discrete_coordinates.shape[1] - 1)]
discrete_coordinates = (
Expand All @@ -319,20 +316,18 @@ def sparse_quantize(
else:
return unique_map

return_args = [discrete_coordinates[unique_map]]
if use_feat:
return_args.append(features[unique_map])
if return_index:
if return_inverse:
return discrete_coordinates[unique_map], unique_map, inverse_map
else:
return discrete_coordinates[unique_map], unique_map
return_args.append(unique_map)
if return_inverse:
return_args.append(inverse_map)

if len(return_args) == 1:
return return_args[0]
else:
if use_feat:
if device == "cuda":
assert isinstance(
features, torch.Tensor
), "For device==cuda, feature must be a torch Tensor"
return discrete_coordinates[unique_map], features[unique_map]
else:
return discrete_coordinates[unique_map]
return tuple(return_args)


def unique_coordinate_map(
Expand Down
6 changes: 5 additions & 1 deletion src/global_pooling_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> GlobalPoolingForwardGPU(
// pooling_mode = in_feat.size(0) / batch_size > 100 ? 1 : 2;

// origin kernel map
if (pooling_mode == PoolingMode::GLOBAL_SUM_POOLING_KERNEL ||
if (pooling_mode == PoolingMode::GLOBAL_SUM_POOLING_DEFAULT ||
pooling_mode == PoolingMode::GLOBAL_AVG_POOLING_DEFAULT ||
pooling_mode == PoolingMode::GLOBAL_SUM_POOLING_KERNEL ||
pooling_mode == PoolingMode::GLOBAL_AVG_POOLING_KERNEL ||
pooling_mode == PoolingMode::GLOBAL_SUM_POOLING_PYTORCH_INDEX ||
pooling_mode == PoolingMode::GLOBAL_AVG_POOLING_PYTORCH_INDEX) {
Expand All @@ -127,6 +129,8 @@ std::tuple<at::Tensor, at::Tensor> GlobalPoolingForwardGPU(
num_nonzero[b] = vec_maps[b].numel();
}
} break;
case PoolingMode::GLOBAL_SUM_POOLING_DEFAULT:
case PoolingMode::GLOBAL_AVG_POOLING_DEFAULT:
case PoolingMode::GLOBAL_SUM_POOLING_KERNEL:
case PoolingMode::GLOBAL_AVG_POOLING_KERNEL: {
const auto &in_outs = p_map_manager->origin_map(p_in_map_key);
Expand Down
35 changes: 7 additions & 28 deletions tests/python/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,13 @@ def test_mapping(self):
print("N unique:", len(mapping), "N:", N)
self.assertTrue((coords == coords[mapping[inverse_mapping]]).all())

index, reverse_index = sparse_quantize(
unique_coords, index, reverse_index = sparse_quantize(
coords, return_index=True, return_inverse=True
)
self.assertTrue((coords == coords[mapping[inverse_mapping]]).all())
self.assertTrue((coords == coords[index[reverse_index]]).all())

def test_label(self):
N = 16575
ignore_label = 255

coords = (np.random.rand(N, 3) * 100).astype(np.int32)
feats = np.random.rand(N, 4)
Expand All @@ -97,17 +96,9 @@ def test_label(self):
coords[:3] = 0
labels[:3] = 2

mapping, colabels = MEB.quantize_label_np(coords, labels, ignore_label)
print("Unique labels and counts:", np.unique(colabels, return_counts=True))
print("N unique:", len(mapping), "N:", N)

mapping, colabels = MEB.quantize_label_th(
torch.from_numpy(coords), torch.from_numpy(labels), ignore_label
qcoords, qfeats, qlabels, mapping, inverse_mapping = sparse_quantize(
coords, feats, labels, return_index=True, return_inverse=True
)
print("Unique labels and counts:", np.unique(colabels, return_counts=True))
print("N unique:", len(mapping), "N:", N)

qcoords, qfeats, qlabels = sparse_quantize(coords, feats, labels, ignore_label)
self.assertTrue(len(mapping) == len(qcoords))

def test_collision(self):
Expand All @@ -118,27 +109,15 @@ def test_collision(self):
coords, labels=labels, ignore_label=255
)
self.assertTrue(len(unique_coords) == 2)
self.assertTrue([0, 0] in unique_coords)
self.assertTrue([0, 1] in unique_coords)
self.assertTrue(torch.IntTensor([0, 0]) in unique_coords)
self.assertTrue(torch.IntTensor([0, 1]) in unique_coords)
self.assertTrue(len(colabels) == 2)
self.assertTrue(255 in colabels)

coords = np.array([[0, 0], [0, 1]], dtype=np.int32)
discrete_coords = sparse_quantize(coords)
self.assertTrue((discrete_coords == unique_coords).all())
discrete_coords = sparse_quantize(torch.from_numpy(coords))
self.assertTrue((discrete_coords == torch.from_numpy(unique_coords)).all())

def test_feature_average(self):
coords = torch.IntTensor([[0, 0], [0, 0], [0, 0], [0, 1]])
feats = torch.FloatTensor([[0, 1, 2, 3]]).t()
mapping, inverse_mapping = MEB.quantize_th(coords)
# inverse_mapping is the output map , range is the out map
avg_feat = MEB.quantization_average_features(
feats, torch.arange(len(feats)), inverse_mapping, len(mapping), 0
)
self.assertTrue(1 in avg_feat)
self.assertTrue(3 in avg_feat)
self.assertTrue((discrete_coords == unique_coords).all())

def test_quantization_size(self):
coords = torch.randn((1000, 3), dtype=torch.float)
Expand Down

0 comments on commit 84c52b7

Please sign in to comment.