
Vectorized coords (#87)
* WIP: PyCoordsKey with vec coords

* WIP: Compilation successful

* tests.coords

* fix region_iter, CPU_ONLY build

* Fix pruning, pooling

* Analytics for docs
chrischoy authored Sep 4, 2019
1 parent 1d9258f commit 1ece298
Showing 37 changed files with 1,343 additions and 1,691 deletions.
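
The change repeated across the Python modules below: the spatial dimension `D` is no longer passed as the first argument of every C++ binding call, and the per-dimension binding classes (`PyCoordsKey{D}`, `PyCoordsManager{D}int32`) are replaced by single classes whose dimension is set at runtime. A minimal sketch of the new convention, using only the binding names that appear in the diff (the commented-out old form is for contrast; argument lists are elided):

```python
import MinkowskiEngineBackend as MEB

D = 3

# Before: one compiled class per dimension, and D repeated in every call.
#   cpp_key = MEB.PyCoordsKey3()
#   fw_fn(D, in_feat, out_feat, ...)

# After: a single PyCoordsKey carries its dimension, so the backend
# functions no longer take D explicitly.
cpp_key = MEB.PyCoordsKey()
cpp_key.setDimension(D)

cpp_manager = MEB.PyCoordsManagerint32()
# fw_fn(in_feat, out_feat, ..., cpp_key, out_cpp_key, cpp_manager)
```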
19 changes: 9 additions & 10 deletions MinkowskiEngine/MinkowskiBroadcast.py
@@ -69,9 +69,8 @@ def forward(ctx, input_features, input_features_global, operation_type,
out_feat = input_features.new()

fw_fn = getattr(MEB, 'BroadcastForward' + get_postfix(input_features))
fw_fn(ctx.in_coords_key.D, ctx.in_feat, ctx.in_feat_glob, out_feat,
ctx.op, ctx.in_coords_key.CPPCoordsKey,
ctx.glob_coords_key.CPPCoordsKey,
fw_fn(ctx.in_feat, ctx.in_feat_glob, out_feat, ctx.op,
ctx.in_coords_key.CPPCoordsKey, ctx.glob_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager)
return out_feat

@@ -83,9 +82,9 @@ def backward(ctx, grad_out_feat):
grad_in_feat = grad_out_feat.new()
grad_in_feat_glob = grad_out_feat.new()
bw_fn = getattr(MEB, 'BroadcastBackward' + get_postfix(grad_out_feat))
bw_fn(ctx.in_coords_key.D, ctx.in_feat, grad_in_feat, ctx.in_feat_glob,
grad_in_feat_glob, grad_out_feat, ctx.op,
ctx.in_coords_key.CPPCoordsKey, ctx.glob_coords_key.CPPCoordsKey,
bw_fn(ctx.in_feat, grad_in_feat, ctx.in_feat_glob, grad_in_feat_glob,
grad_out_feat, ctx.op, ctx.in_coords_key.CPPCoordsKey,
ctx.glob_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager)
return grad_in_feat, grad_in_feat_glob, None, None, None, None

@@ -147,8 +146,8 @@ def __init__(self, dimension=-1):
space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiBroadcastAddition, self).__init__(
OperationType.ADDITION, dimension)
super(MinkowskiBroadcastAddition,
self).__init__(OperationType.ADDITION, dimension)


class MinkowskiBroadcastMultiplication(AbstractMinkowskiBroadcast):
@@ -180,8 +179,8 @@ def __init__(self, dimension=-1):
space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiBroadcastMultiplication, self).__init__(
OperationType.MULTIPLICATION, dimension)
super(MinkowskiBroadcastMultiplication,
self).__init__(OperationType.MULTIPLICATION, dimension)


class MinkowskiBroadcast(Module):
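
For orientation, a hedged usage sketch of the broadcast modules edited above: a globally pooled feature is broadcast back onto every nonzero element. The constructors and the `dimension` keyword follow this version of the library; the coordinate layout (batch index in the last column) is an assumption here.

```python
import torch
import MinkowskiEngine as ME

# Two points of one sample; batch index assumed to be the last column.
coords = torch.IntTensor([[0, 0, 0, 0],
                          [1, 2, 3, 0]])
feats = torch.rand(2, 4)
x = ME.SparseTensor(feats, coords=coords)

glob_pool = ME.MinkowskiGlobalPooling(dimension=3)
bcast_add = ME.MinkowskiBroadcastAddition(dimension=3)

x_glob = glob_pool(x)     # one pooled feature vector per batch
y = bcast_add(x, x_glob)  # pooled vector added back onto every point
```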
18 changes: 10 additions & 8 deletions MinkowskiEngine/MinkowskiConvolution.py
@@ -56,7 +56,8 @@ def forward(ctx,
# Prep arguments
# Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat)
assert input_features.shape[1] == kernel.shape[1], \
"The input shape " + str(list(input_features.shape)) + " does not match the kernel shape " + str(list(kernel.shape))
"The input shape " + str(list(input_features.shape)) + \
" does not match the kernel shape " + str(list(kernel.shape))
if out_coords_key is None:
out_coords_key = CoordsKey(in_coords_key.D)
assert in_coords_key.D == out_coords_key.D
@@ -81,7 +82,7 @@ def forward(ctx,
out_feat = input_features.new()

fw_fn = getattr(MEB, 'ConvolutionForward' + get_postfix(input_features))
fw_fn(D, ctx.in_feat, out_feat, kernel,
fw_fn(ctx.in_feat, out_feat, kernel,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -99,8 +100,8 @@ def backward(ctx, grad_out_feat):
grad_kernel = grad_out_feat.new()
D = ctx.in_coords_key.D
bw_fn = getattr(MEB, 'ConvolutionBackward' + get_postfix(grad_out_feat))
bw_fn(D, ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel,
grad_kernel, convert_to_int_list(ctx.tensor_stride, D),
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel, grad_kernel,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
convert_to_int_list(ctx.dilation, D), ctx.region_type,
@@ -130,7 +131,8 @@ def forward(ctx,
# Prep arguments
# Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat)
assert input_features.shape[1] == kernel.shape[1], \
"The input shape " + str(list(input_features.shape)) + " does not match the kernel shape " + str(list(kernel.shape))
"The input shape " + str(list(input_features.shape)) + \
" does not match the kernel shape " + str(list(kernel.shape))
if out_coords_key is None:
out_coords_key = CoordsKey(in_coords_key.D)
assert in_coords_key.D == out_coords_key.D
@@ -156,7 +158,7 @@ def forward(ctx,

fw_fn = getattr(
MEB, 'ConvolutionTransposeForward' + get_postfix(input_features))
fw_fn(D, ctx.in_feat, out_feat, kernel,
fw_fn(ctx.in_feat, out_feat, kernel,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -175,8 +177,8 @@ def backward(ctx, grad_out_feat):
D = ctx.in_coords_key.D
bw_fn = getattr(
MEB, 'ConvolutionTransposeBackward' + get_postfix(grad_out_feat))
bw_fn(D, ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel,
grad_kernel, convert_to_int_list(ctx.tensor_stride, D),
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel, grad_kernel,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
convert_to_int_list(ctx.dilation, D), ctx.region_type,
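
For reference, a hedged sketch of the layer these functions back; the constructor keywords follow this version's examples, and the coordinate layout (batch index last) is an assumption.

```python
import torch
import MinkowskiEngine as ME

coords = torch.IntTensor([[0, 0, 0, 0],
                          [0, 1, 1, 0],
                          [1, 1, 0, 0]])
feats = torch.rand(3, 3)
x = ME.SparseTensor(feats, coords=coords)

conv = ME.MinkowskiConvolution(
    in_channels=3, out_channels=8, kernel_size=3, stride=1, dimension=3)
y = conv(x)  # SparseTensor with 8-channel features on the output coordinates
```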
5 changes: 3 additions & 2 deletions MinkowskiEngine/MinkowskiCoords.py
@@ -31,7 +31,8 @@ class CoordsKey():

def __init__(self, D):
self.D = D
self.CPPCoordsKey = getattr(MEB, f'PyCoordsKey{self.D}')()
self.CPPCoordsKey = getattr(MEB, f'PyCoordsKey')()
self.CPPCoordsKey.setDimension(D)

def setKey(self, key):
self.CPPCoordsKey.setKey(key)
@@ -56,7 +57,7 @@ def __init__(self, D=-1):
if D < 1:
raise ValueError(f"Invalid dimension {D}")
self.D = D
CPPCoordsManager = getattr(MEB, f'PyCoordsManager{D}int32')
CPPCoordsManager = getattr(MEB, f'PyCoordsManagerint32')
coords_man = CPPCoordsManager()
self.CPPCoordsManager = coords_man

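
This file is the heart of the change: `CoordsKey` and `CoordsManager` now configure a single runtime-dimension C++ class instead of looking up a dimension-templated one, which is what enables the "unlimited high-(spatial)-dimensional" support advertised in the README below. A short sketch (imports follow the file path; the loop bound is illustrative):

```python
from MinkowskiEngine.MinkowskiCoords import CoordsKey, CoordsManager

# Any spatial dimension now uses the same compiled classes; previously each D
# required its own PyCoordsKey{D} / PyCoordsManager{D}int32 specialization.
for D in (2, 3, 4, 7):
    key = CoordsKey(D)        # wraps MEB.PyCoordsKey and calls setDimension(D)
    man = CoordsManager(D=D)  # wraps MEB.PyCoordsManagerint32
```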
23 changes: 10 additions & 13 deletions MinkowskiEngine/MinkowskiNormalization.py
@@ -88,30 +88,28 @@ def forward(ctx,

mean = in_feat.new()
num_nonzero = in_feat.new()
D = in_coords_key.D

cpp_in_coords_key = in_coords_key.CPPCoordsKey
cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
cpp_coords_manager = coords_manager.CPPCoordsManager

gpool_forward(D, in_feat, mean, num_nonzero, cpp_in_coords_key,
gpool_forward(in_feat, mean, num_nonzero, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager, True)
# X - \mu
centered_feat = in_feat.new()
broadcast_forward(D, in_feat, -mean, centered_feat, add,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)
broadcast_forward(in_feat, -mean, centered_feat, add, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager)

# Variance = 1/N \sum (X - \mu) ** 2
variance = in_feat.new()
gpool_forward(D, centered_feat**2, variance, num_nonzero,
gpool_forward(centered_feat**2, variance, num_nonzero,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager, True)

# norm_feat = (X - \mu) / \sigma
inv_std = 1 / (variance + 1e-8).sqrt()
norm_feat = in_feat.new()
broadcast_forward(D, centered_feat, inv_std, norm_feat, multiply,
broadcast_forward(centered_feat, inv_std, norm_feat, multiply,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)

@@ -129,7 +127,6 @@ def backward(ctx, out_grad):

# To prevent the memory leakage, compute the norm again
inv_std, norm_feat = ctx.saved_variables
D = in_coords_key.D

gpool_forward = getattr(MEB,
'GlobalPoolingForward' + get_postfix(out_grad))
@@ -145,28 +142,28 @@ def backward(ctx, out_grad):
# 1/N \sum dout
num_nonzero = out_grad.new()
mean_dout = out_grad.new()
gpool_forward(D, out_grad, mean_dout, num_nonzero, cpp_in_coords_key,
gpool_forward(out_grad, mean_dout, num_nonzero, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager, True)

# 1/N \sum (dout * out)
mean_dout_feat = out_grad.new()
gpool_forward(D, out_grad * norm_feat, mean_dout_feat, num_nonzero,
gpool_forward(out_grad * norm_feat, mean_dout_feat, num_nonzero,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager, True)

# out * 1/N \sum (dout * out)
feat_mean_dout_feat = out_grad.new()
broadcast_forward(D, norm_feat, mean_dout_feat, feat_mean_dout_feat,
broadcast_forward(norm_feat, mean_dout_feat, feat_mean_dout_feat,
multiply, cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)

unnorm_din = out_grad.new()
broadcast_forward(D, out_grad - feat_mean_dout_feat, -mean_dout,
broadcast_forward(out_grad - feat_mean_dout_feat, -mean_dout,
unnorm_din, add, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager)

norm_din = out_grad.new()
broadcast_forward(D, unnorm_din, inv_std, norm_din, multiply,
broadcast_forward(unnorm_din, inv_std, norm_din, multiply,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)

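
For reference, the forward computation behind these calls is unchanged by the signature cleanup; per batch, over the N nonzero elements it is the usual instance normalization:

```latex
\mu = \frac{1}{N}\sum_{i} x_i, \qquad
\sigma^2 = \frac{1}{N}\sum_{i}\left(x_i - \mu\right)^2, \qquad
y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + 10^{-8}}}
```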
25 changes: 11 additions & 14 deletions MinkowskiEngine/MinkowskiPooling.py
@@ -72,7 +72,7 @@ def forward(ctx,
ctx.max_index = max_index

fw_fn = getattr(MEB, 'MaxPoolingForward' + get_postfix(input_features))
fw_fn(D, input_features, out_feat, max_index,
fw_fn(input_features, out_feat, max_index,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -89,7 +89,7 @@ def backward(ctx, grad_out_feat):
grad_in_feat = grad_out_feat.new()
D = ctx.in_coords_key.D
bw_fn = getattr(MEB, 'MaxPoolingBackward' + get_postfix(grad_out_feat))
bw_fn(D, ctx.in_feat, grad_in_feat, grad_out_feat, ctx.max_index,
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.max_index,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -146,7 +146,7 @@ def forward(ctx,
ctx.num_nonzero = input_features.new()

fw_fn = getattr(MEB, 'AvgPoolingForward' + get_postfix(input_features))
fw_fn(D, ctx.in_feat, out_feat, ctx.num_nonzero,
fw_fn(ctx.in_feat, out_feat, ctx.num_nonzero,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -163,7 +163,7 @@ def backward(ctx, grad_out_feat):
grad_in_feat = grad_out_feat.new()
D = ctx.in_coords_key.D
bw_fn = getattr(MEB, 'AvgPoolingBackward' + get_postfix(grad_out_feat))
bw_fn(D, ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -517,7 +517,7 @@ def forward(ctx,
D = in_coords_key.D
fw_fn = getattr(MEB,
'PoolingTransposeForward' + get_postfix(input_features))
fw_fn(in_coords_key.D, ctx.in_feat, out_feat, ctx.num_nonzero,
fw_fn(ctx.in_feat, out_feat, ctx.num_nonzero,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
@@ -532,8 +532,8 @@ def backward(ctx, grad_out_feat):
D = ctx.in_coords_key.D
bw_fn = getattr(MEB,
'PoolingTransposeBackward' + get_postfix(grad_out_feat))
bw_fn(ctx.in_coords_key.D, ctx.in_feat, grad_in_feat, grad_out_feat,
ctx.num_nonzero, convert_to_int_list(ctx.tensor_stride, D),
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
convert_to_int_list(ctx.dilation, D), ctx.region_type,
@@ -644,21 +644,19 @@ def forward(ctx,
ctx.num_nonzero = input_features.new()
ctx.coords_manager = coords_manager

D = in_coords_key.D
fw_fn = getattr(MEB,
'GlobalPoolingForward' + get_postfix(input_features))
fw_fn(D, ctx.in_feat, out_feat, ctx.num_nonzero,
fw_fn(ctx.in_feat, out_feat, ctx.num_nonzero,
ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager, ctx.average)
return out_feat

@staticmethod
def backward(ctx, grad_out_feat):
grad_in_feat = grad_out_feat.new()
D = ctx.in_coords_key.D
bw_fn = getattr(MEB,
'GlobalPoolingBackward' + get_postfix(grad_out_feat))
bw_fn(D, ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager, ctx.average)
return grad_in_feat, None, None, None, None, None
@@ -701,9 +699,8 @@ def forward(self, input):
assert input.D == self.dimension

out_coords_key = CoordsKey(input.coords_key.D)
output = self.pooling.apply(input.F, self.average,
input.coords_key, out_coords_key,
input.coords_man)
output = self.pooling.apply(input.F, self.average, input.coords_key,
out_coords_key, input.coords_man)

return SparseTensor(
output, coords_key=out_coords_key, coords_manager=input.coords_man)
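
A hedged usage sketch of the strided pooling layers whose bindings changed above; the constructor keywords mirror the convolution layer and are an assumption for this version:

```python
import torch
import MinkowskiEngine as ME

coords = torch.IntTensor([[0, 0, 0, 0],
                          [0, 1, 1, 0],
                          [2, 3, 1, 0]])
feats = torch.rand(3, 4)
x = ME.SparseTensor(feats, coords=coords)

max_pool = ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3)
avg_pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=3)

y_max = max_pool(x)  # features max-pooled onto stride-2 output coordinates
y_avg = avg_pool(x)  # same output coordinates, average-pooled features
```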
13 changes: 8 additions & 5 deletions MinkowskiEngine/MinkowskiPruning.py
@@ -37,7 +37,10 @@ class MinkowskiPruningFunction(Function):
def forward(ctx, in_feat, use_feat, in_coords_key, out_coords_key,
coords_manager):
assert in_feat.size(0) == use_feat.size(0)
assert isinstance(use_feat, torch.ByteTensor)
assert isinstance(use_feat, torch.ByteTensor) \
or isinstance(use_feat, torch.BoolTensor), "use_feat must be a bool/byte tensor."
if isinstance(use_feat, torch.BoolTensor):
use_feat = use_feat.byte()
if not in_feat.is_contiguous():
in_feat = in_feat.contiguous()
if not use_feat.is_contiguous():
@@ -50,8 +53,8 @@ def forward(ctx, in_feat, use_feat, in_coords_key, out_coords_key,
out_feat = in_feat.new()

fw_fn = getattr(MEB, 'PruningForward' + get_postfix(in_feat))
fw_fn(ctx.in_coords_key.D, in_feat, out_feat, use_feat,
ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
fw_fn(in_feat, out_feat, use_feat, ctx.in_coords_key.CPPCoordsKey,
ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager)
return out_feat

@@ -62,8 +65,8 @@ def backward(ctx, grad_out_feat):

grad_in_feat = grad_out_feat.new()
bw_fn = getattr(MEB, 'PruningBackward' + get_postfix(grad_out_feat))
bw_fn(ctx.in_coords_key.D, grad_in_feat, grad_out_feat,
ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
bw_fn(grad_in_feat, grad_out_feat, ctx.in_coords_key.CPPCoordsKey,
ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager)
return grad_in_feat, None, None, None, None, None

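
The pruning function now also accepts a `torch.BoolTensor` mask (converted to byte internally), which matches what comparison operators return in PyTorch 1.2, the minimum version now listed in the README. A hedged usage sketch; the `dimension` argument follows the other modules and is an assumption here:

```python
import torch
import MinkowskiEngine as ME

coords = torch.IntTensor([[0, 0, 0, 0],
                          [0, 1, 1, 0],
                          [2, 3, 1, 0]])
feats = torch.rand(3, 4)
x = ME.SparseTensor(feats, coords=coords)

pruning = ME.MinkowskiPruning(dimension=3)

keep = x.F[:, 0] > 0  # BoolTensor mask in PyTorch >= 1.2
y = pruning(x, keep)  # SparseTensor restricted to the kept coordinates
```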
2 changes: 1 addition & 1 deletion MinkowskiEngine/__init__.py
@@ -21,7 +21,7 @@
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
__version__ = "0.2.6"
__version__ = "0.2.7"

import os
import sys
3 changes: 2 additions & 1 deletion README.md
@@ -4,6 +4,7 @@ The MinkowskiEngine is an auto-differentiation library for sparse tensors. It su

## Features

- Unlimited high-(spatial)-dimensional sparse tensor support
- Dynamic computation graph
- Custom kernel shapes
- [Generalized sparse convolution](https://stanfordvl.github.io/MinkowskiEngine/generalized_sparse_conv.html)
@@ -18,7 +19,7 @@ The MinkowskiEngine is an auto-differentiation library for sparse tensors. It su

- Ubuntu 14.04 or higher
- CUDA 10.0 or higher
- pytorch 1.1 or higher
- pytorch 1.2 or higher


## Installation
10 changes: 10 additions & 0 deletions docs/_templates/layout.html
@@ -4,4 +4,14 @@
<a href="https://github.com/StanfordVL/MinkowskiEngine">
<img style="position: absolute; top: 0; right: 0; border: 0;" src="https://s3.amazonaws.com/github/ribbons/forkme_right_darkblue_121621.png" alt="Fork me on GitHub">
</a>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-43980256-3"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());

gtag('config', 'UA-43980256-3');
</script>

{% endblock %}
3 changes: 2 additions & 1 deletion docs/overview.md
@@ -4,6 +4,7 @@ The MinkowskiEngine is an auto-differentiation library for sparse tensors. It su

## Features

- Unlimited high-(spatial)-dimensional sparse tensor support
- Dynamic computation graph
- Custom kernel shapes
- [Generalized sparse convolution](https://stanfordvl.github.io/MinkowskiEngine/generalized_sparse_conv.html)
@@ -18,7 +19,7 @@ The MinkowskiEngine is an auto-differentiation library for sparse tensors. It su

- Ubuntu 14.04 or higher
- CUDA 10.0 or higher
- pytorch 1.1 or higher
- pytorch 1.2 or higher


## Installation
Empty file added examples/__init__.py
Empty file.
