Skip to content

Commit

Permalink
Global pooling policy
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit 48b747b722efe077ead9e952e8f9f13d3019c359
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Mon Dec 23 15:03:42 2019 -0800

    fix norm

commit ac616e5dcf869ac060adcf1ba375c9ccbd3c9851
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Mon Dec 23 14:41:56 2019 -0800

    Fix pruning, broadcast

commit 71b7f47fdef75a7c50e11d4d5a65184bfb04c30e
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Mon Dec 23 08:48:43 2019 -0800

    GlobalPooling done

commit 382ea9979d3602940f8927e8f25b937e879b8ede
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sun Dec 22 23:05:08 2019 -0800

    auto policy

commit 4bd38889a4e71d76f7e8766ec0218936b160dd38
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sun Dec 22 19:00:22 2019 -0800

    Global pooling mode

commit adfdda80a115f036265caa081de0ed29c739fa69
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sat Dec 21 20:34:36 2019 -0800

    pooling modes

commit ab0dc88bb9c7687c743848e19fc1535c26ac3323
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sat Dec 21 19:46:38 2019 -0800

    global with vectorized tensor return

commit 8c5ccab71f1151661b44da44396e8007e9a4a2eb
Author: Chris Choy <chrischoy@ai.stanford.edu>
Date:   Sat Dec 21 17:41:33 2019 -0800

    separate in out maps for braodcast/pool
  • Loading branch information
chrischoy committed Dec 23, 2019
1 parent cc98754 commit 64518c1
Show file tree
Hide file tree
Showing 32 changed files with 833 additions and 488 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@
- Updated MinkowskiUnion, MinkowskiPruning docs
- Use cudaMalloc instead of `at::Tensor` for GPU memory management for illegal memory access, invalid arg.
- Region hypercube iterator with even numbered kernel
- Fix global reduction in-out map with non contiguous batch indices
- GlobalPooling with torch reduction
- GlobalPoolingMode with index select and sparse backbone
- If batch size == 1, skip the backend
- Added CoordsManager functions
- `get_batch_size`
- `get_batch_indices`
- `set_origin_coords_key`
- Updated CoordsManager function `get_row_indices_per_batch` to return a list of `torch.LongTensor` for mapping indices. The corresponding batch indices is accessible by `get_batch_indices`.
- Update `MinkowskiBroadcast`, `MinkowskiBroadcastConcatenation` to use row indices per batch (`getRowIndicesPerBatch`)


## [0.3.1] - 2019-12-15
Expand Down
18 changes: 18 additions & 0 deletions MinkowskiEngine/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ def get_postfix(tensor):
return postfix


class GlobalPoolingMode(Enum):
"""
Define the global pooling mode
"""
AUTO = 0, 'AUTO'
INDEX_SELECT = 1, 'INDEX_SELECT'
SPARSE = 2, 'SPARSE'

def __new__(cls, value, name):
member = object.__new__(cls)
member._value_ = value
member.fullname = name
return member

def __int__(self):
return self.value


class RegionType(Enum):
"""
Define the kernel region type
Expand Down
33 changes: 14 additions & 19 deletions MinkowskiEngine/MinkowskiBroadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ def forward(ctx, input_features, input_features_global, operation_type,
ctx.glob_coords_key = glob_coords_key
ctx.coords_manager = coords_manager

out_feat = input_features.new()

fw_fn = getattr(MEB, 'BroadcastForward' + get_postfix(input_features))
fw_fn(ctx.in_feat, ctx.in_feat_glob, out_feat, ctx.op,
ctx.in_coords_key.CPPCoordsKey, ctx.glob_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager)
out_feat = fw_fn(ctx.in_feat, ctx.in_feat_glob, ctx.op,
ctx.in_coords_key.CPPCoordsKey,
ctx.glob_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager)
return out_feat

@staticmethod
Expand Down Expand Up @@ -227,15 +226,13 @@ def forward(self, input, input_glob):
assert isinstance(input_glob, SparseTensor)
assert input.D == self.dimension

coo = input.coords_man.get_coo_broadcast_coords(input.coords_key)
perm_mat = torch.sparse_coo_tensor(
coo,
torch.ones(len(input), dtype=input.dtype, device=input.device),
requires_grad=False)
broadcast_feat = torch.empty_like(input.F)
row_inds = input.coords_man.get_row_indices_per_batch(input.coords_key)
for b, row_ind in enumerate(row_inds):
broadcast_feat[row_ind] = input_glob.F[b]

broadcasted_input_glob = perm_mat.mm(input_glob.F)
return SparseTensor(
broadcasted_input_glob,
broadcast_feat,
coords_key=input.coords_key,
coords_manager=input.coords_man)

Expand Down Expand Up @@ -266,14 +263,12 @@ def forward(self, input, input_glob):
assert isinstance(input_glob, SparseTensor)
assert input.D == self.dimension

coo = input.coords_man.get_coo_broadcast_coords(input.coords_key)
perm_mat = torch.sparse_coo_tensor(
coo,
torch.ones(len(input), dtype=input.dtype, device=input.device),
requires_grad=False)
broadcast_feat = torch.empty_like(input.F)
row_inds = input.coords_man.get_row_indices_per_batch(input.coords_key)
for b, row_ind in enumerate(row_inds):
broadcast_feat[row_ind] = input_glob.F[b]

broadcasted_input_glob = perm_mat.mm(input_glob.F)
broadcast_cat = torch.cat((input.F, broadcasted_input_glob), dim=1)
broadcast_cat = torch.cat((input.F, broadcast_feat), dim=1)
return SparseTensor(
broadcast_cat,
coords_key=input.coords_key,
Expand Down
39 changes: 17 additions & 22 deletions MinkowskiEngine/MinkowskiCoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,36 +145,31 @@ def get_coords(self, coords_key):
self.CPPCoordsManager.getCoords(coords, coords_key.CPPCoordsKey)
return coords

def get_row_indices_per_batch(self, coords_key):
r"""
return a list of unique batch indices, and a list of lists of row indices per batch.
def get_batch_size(self):
return self.CPPCoordsManager.getBatchSize()

def get_batch_indices(self):
return self.CPPCoordsManager.getBatchIndices()

def set_origin_coords_key(self, coords_key):
self.CPPCoordsManager.setOriginCoordsKey(coords_key.CPPCoordsKey)

def get_row_indices_per_batch(self, coords_key, out_coords_key=None):
r"""Return a list of lists of row indices per batch.
The corresponding batch indices are accessible by `get_batch_indices`.
.. code-block:: python
sp_tensor = ME.SparseTensor(features, coords=coordinates)
batch_indices, list_of_row_indices = sp_tensor.coords_man.get_row_indices_per_batch(sp_tensor.coords_key)
batch_indices = sp_tensor.coords_man.get_row_indices_per_batch(sp_tensor.coords_key)
"""
assert isinstance(coords_key, CoordsKey)
out_key = CoordsKey(self.D)
if out_coords_key is None:
out_coords_key = CoordsKey(self.D)
return self.CPPCoordsManager.getRowIndicesPerBatch(
coords_key.CPPCoordsKey, out_key.CPPCoordsKey)

def get_coo_broadcast_coords(self, coords_key, transpose=False):
_, list_of_row_indices = self.get_row_indices_per_batch(coords_key)
coos = []
for batch_ind, row_inds in enumerate(list_of_row_indices):
if transpose:
coo = torch.LongTensor([[
batch_ind,
] * len(row_inds), row_inds])
else:
coo = torch.LongTensor([row_inds, [
batch_ind,
] * len(row_inds)])
coos.append(coo)

return torch.cat(coos, dim=1)
coords_key.CPPCoordsKey, out_coords_key.CPPCoordsKey)

def get_kernel_map(self,
in_tensor_strides,
Expand Down
81 changes: 44 additions & 37 deletions MinkowskiEngine/MinkowskiNormalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from MinkowskiBroadcast import MinkowskiBroadcastAddition, MinkowskiBroadcastMultiplication, OperationType, operation_type_to_int
import MinkowskiEngineBackend as MEB
from MinkowskiCoords import CoordsKey
from Common import get_postfix
from Common import get_postfix, GlobalPoolingMode


class MinkowskiBatchNorm(Module):
Expand Down Expand Up @@ -152,9 +152,12 @@ class MinkowskiInstanceNormFunction(Function):
@staticmethod
def forward(ctx,
in_feat,
mode=GlobalPoolingMode.AUTO,
in_coords_key=None,
glob_coords_key=None,
coords_manager=None):
assert isinstance(mode, GlobalPoolingMode), \
f"Mode must be an instance of GlobalPoolingMode, {mode}"
if glob_coords_key is None:
glob_coords_key = CoordsKey(in_coords_key.D)

Expand All @@ -172,26 +175,29 @@ def forward(ctx,
cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
cpp_coords_manager = coords_manager.CPPCoordsManager

gpool_forward(in_feat, mean, num_nonzero, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager, True)
mean, num_nonzero = gpool_forward(in_feat, cpp_in_coords_key,
cpp_glob_coords_key,
cpp_coords_manager, True, mode.value)
# X - \mu
centered_feat = in_feat.new()
broadcast_forward(in_feat, -mean, centered_feat, add, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager)
centered_feat = broadcast_forward(in_feat, -mean, add,
cpp_in_coords_key,
cpp_glob_coords_key,
cpp_coords_manager)

# Variance = 1/N \sum (X - \mu) ** 2
variance = in_feat.new()
gpool_forward(centered_feat**2, variance, num_nonzero,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager, True)
variance, num_nonzero = gpool_forward(centered_feat**2,
cpp_in_coords_key,
cpp_glob_coords_key,
cpp_coords_manager, True,
mode.value)

# norm_feat = (X - \mu) / \sigma
inv_std = 1 / (variance + 1e-8).sqrt()
norm_feat = in_feat.new()
broadcast_forward(centered_feat, inv_std, norm_feat, multiply,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)
norm_feat = broadcast_forward(centered_feat, inv_std, multiply,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)

ctx.mode = mode
ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
ctx.coords_manager = coords_manager
# For GPU tensors, must use save_for_backward.
Expand Down Expand Up @@ -219,32 +225,32 @@ def backward(ctx, out_grad):
cpp_coords_manager = coords_manager.CPPCoordsManager

# 1/N \sum dout
num_nonzero = out_grad.new()
mean_dout = out_grad.new()
gpool_forward(out_grad, mean_dout, num_nonzero, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager, True)
mean_dout, num_nonzero = gpool_forward(out_grad,
cpp_in_coords_key,
cpp_glob_coords_key,
cpp_coords_manager, True,
ctx.mode.value)

# 1/N \sum (dout * out)
mean_dout_feat = out_grad.new()
gpool_forward(out_grad * norm_feat, mean_dout_feat, num_nonzero,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager, True)
mean_dout_feat, num_nonzero = gpool_forward(out_grad * norm_feat,
cpp_in_coords_key,
cpp_glob_coords_key,
cpp_coords_manager, True,
ctx.mode.value)

# out * 1/N \sum (dout * out)
feat_mean_dout_feat = out_grad.new()
broadcast_forward(norm_feat, mean_dout_feat, feat_mean_dout_feat,
multiply, cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)
feat_mean_dout_feat = broadcast_forward(norm_feat, mean_dout_feat,
multiply, cpp_in_coords_key,
cpp_glob_coords_key,
cpp_coords_manager)

unnorm_din = out_grad.new()
broadcast_forward(out_grad - feat_mean_dout_feat, -mean_dout,
unnorm_din, add, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager)
unnorm_din = broadcast_forward(out_grad - feat_mean_dout_feat,
-mean_dout, add, cpp_in_coords_key,
cpp_glob_coords_key, cpp_coords_manager)

norm_din = out_grad.new()
broadcast_forward(unnorm_din, inv_std, norm_din, multiply,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)
norm_din = broadcast_forward(unnorm_din, inv_std, multiply,
cpp_in_coords_key, cpp_glob_coords_key,
cpp_coords_manager)

return norm_din, None, None, None, None

Expand Down Expand Up @@ -301,7 +307,7 @@ class MinkowskiInstanceNorm(Module):
"""

def __init__(self, num_features, dimension=-1):
def __init__(self, num_features, mode=GlobalPoolingMode.AUTO, dimension=-1):
r"""
Args:
Expand All @@ -316,6 +322,7 @@ def __init__(self, num_features, dimension=-1):
self.bias = nn.Parameter(torch.zeros(1, num_features))
self.dimension = dimension
self.reset_parameters()
self.mode = mode
self.inst_norm = MinkowskiInstanceNormFunction()

def __repr__(self):
Expand All @@ -330,8 +337,8 @@ def forward(self, input):
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

output = self.inst_norm.apply(input.F, input.coords_key, None,
input.coords_man)
output = self.inst_norm.apply(input.F, self.mode, input.coords_key,
None, input.coords_man)
output = output * self.weight + self.bias

return SparseTensor(
Expand Down
39 changes: 26 additions & 13 deletions MinkowskiEngine/MinkowskiPooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

import MinkowskiEngineBackend as MEB
from SparseTensor import SparseTensor
from Common import KernelGenerator, RegionType, MinkowskiModuleBase, \
from Common import KernelGenerator, RegionType, GlobalPoolingMode, \
MinkowskiModuleBase, \
convert_to_int_list, convert_to_int_tensor, \
prep_args, save_ctx, get_postfix
from MinkowskiCoords import CoordsKey
Expand Down Expand Up @@ -678,35 +679,43 @@ class MinkowskiGlobalPoolingFunction(Function):
def forward(ctx,
input_features,
average=True,
mode=GlobalPoolingMode.AUTO,
in_coords_key=None,
out_coords_key=None,
coords_manager=None):
if out_coords_key is None:
out_coords_key = CoordsKey(in_coords_key.D)
assert isinstance(mode, GlobalPoolingMode), \
f"Mode must be an instance of GlobalPoolingMode, {mode}"

ctx.in_coords_key = in_coords_key
ctx.out_coords_key = out_coords_key

ctx.in_feat = input_features
out_feat = input_features.new()
ctx.average = average
ctx.num_nonzero = input_features.new()
ctx.coords_manager = coords_manager
ctx.mode = mode.value

fw_fn = getattr(MEB,
'GlobalPoolingForward' + get_postfix(input_features))
fw_fn(ctx.in_feat, out_feat, ctx.num_nonzero,
ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager, ctx.average)
out_feat, num_nonzero = fw_fn(ctx.in_feat,
ctx.in_coords_key.CPPCoordsKey,
ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager,
ctx.average, ctx.mode)

ctx.num_nonzero = num_nonzero

return out_feat

@staticmethod
def backward(ctx, grad_out_feat):
grad_in_feat = grad_out_feat.new()
bw_fn = getattr(MEB,
'GlobalPoolingBackward' + get_postfix(grad_out_feat))
bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager, ctx.average)
grad_in_feat = bw_fn(ctx.in_feat, grad_out_feat, ctx.num_nonzero,
ctx.in_coords_key.CPPCoordsKey,
ctx.out_coords_key.CPPCoordsKey,
ctx.coords_manager.CPPCoordsManager, ctx.average)
return grad_in_feat, None, None, None, None, None


Expand All @@ -720,7 +729,7 @@ class MinkowskiGlobalPooling(MinkowskiModuleBase):
"""

def __init__(self, average=True, dimension=-1):
def __init__(self, average=True, mode=GlobalPoolingMode.AUTO, dimension=-1):
r"""Reduces sparse coords into points at origin, i.e. reduce each point
cloud into a point at the origin, returning batch_size number of points
[[0, 0, ..., 0], [0, 0, ..., 1],, [0, 0, ..., 2]] where the last elem
Expand All @@ -737,7 +746,10 @@ def __init__(self, average=True, dimension=-1):
"""
super(MinkowskiGlobalPooling, self).__init__()
assert dimension > 0, f"dimension must be a positive integer, {dimension}"
assert isinstance(mode, GlobalPoolingMode), \
f"Mode must be an instance of GlobalPoolingMode. mode={mode}"

self.mode = mode
self.average = average
self.dimension = dimension
self.pooling = MinkowskiGlobalPoolingFunction()
Expand All @@ -747,8 +759,9 @@ def forward(self, input):
assert input.D == self.dimension

out_coords_key = CoordsKey(input.coords_key.D)
output = self.pooling.apply(input.F, self.average, input.coords_key,
out_coords_key, input.coords_man)
output = self.pooling.apply(input.F, self.average, self.mode,
input.coords_key, out_coords_key,
input.coords_man)

return SparseTensor(
output, coords_key=out_coords_key, coords_manager=input.coords_man)
Expand Down
Loading

0 comments on commit 64518c1

Please sign in to comment.