global pool, broadcast without arg, minor
chrischoy committed Jan 31, 2020
1 parent 5d8dbeb commit 6a932ca
Showing 15 changed files with 79 additions and 122 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,14 @@
# Change Log

## [master] - 2020-01-30

### Changed

- Fix dtype from double to float in the multi-GPU example
- Remove the dimension input argument from the GlobalPooling and Broadcast functions
- Add a tensor stride > 0 check to kernel map generation


## [0.4.1] - 2020-01-28

### Changed
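The headline change in this commit is the removal of the `dimension` constructor argument from the global pooling and broadcast layers; the updated forward methods below rely on the dimension carried by the input `SparseTensor` instead. A minimal before/after sketch based on the signatures in this diff (the `D = 3` value is illustrative):

```python
import MinkowskiEngine as ME

D = 3  # illustrative spatial dimension

# Before this commit (<= 0.4.1): the dimension had to be passed explicitly.
# pool = ME.MinkowskiGlobalPooling(dimension=D)
# add  = ME.MinkowskiBroadcastAddition(dimension=D)

# After this commit: no dimension argument; it is taken from the input SparseTensor.
pool = ME.MinkowskiGlobalPooling()
add = ME.MinkowskiBroadcastAddition()
mul = ME.MinkowskiBroadcastMultiplication()
```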
47 changes: 5 additions & 42 deletions MinkowskiEngine/MinkowskiBroadcast.py
@@ -90,19 +90,16 @@ def backward(ctx, grad_out_feat):

class AbstractMinkowskiBroadcast(Module):

def __init__(self, operation_type, dimension=-1):
def __init__(self, operation_type):
super(AbstractMinkowskiBroadcast, self).__init__()
assert isinstance(operation_type, OperationType)
assert dimension > 0, f"dimension must be a positive integer, {dimension}"

self.operation_type = operation_type
self.dimension = dimension

self.broadcast = MinkowskiBroadcastFunction()

def forward(self, input, input_glob):
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

output = self.broadcast.apply(input.F, input_glob.F,
self.operation_type, input.coords_key,
@@ -136,17 +133,8 @@ class MinkowskiBroadcastAddition(AbstractMinkowskiBroadcast):
"""

def __init__(self, dimension=-1):
r"""a broadcast addition layer.
Args:
:attr:`dimension` (int): the dimension of the space where all the
inputs and the network is defined. For example, images are in a 2D
space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiBroadcastAddition,
self).__init__(OperationType.ADDITION, dimension)
def __init__(self):
AbstractMinkowskiBroadcast.__init__(self, OperationType.ADDITION)


class MinkowskiBroadcastMultiplication(AbstractMinkowskiBroadcast):
@@ -169,17 +157,8 @@ class MinkowskiBroadcastMultiplication(AbstractMinkowskiBroadcast):
"""

def __init__(self, dimension=-1):
r"""a broadcast multiplication layer.
Args:
:attr:`dimension` (int): the dimension of the space where all the
inputs and the network is defined. For example, images are in a 2D
space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiBroadcastMultiplication,
self).__init__(OperationType.MULTIPLICATION, dimension)
def __init__(self):
AbstractMinkowskiBroadcast.__init__(self, OperationType.MULTIPLICATION)


class MinkowskiBroadcast(Module):
@@ -204,27 +183,12 @@ class MinkowskiBroadcast(Module):
"""

def __init__(self, dimension=-1):
r"""broadcast layer.
Args:
:attr:`dimension` (int): the dimension of the space where all the
inputs and the network is defined. For example, images are in a 2D
space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiBroadcast, self).__init__()
assert dimension > 0, f"dimension must be a positive integer, {dimension}"

self.dimension = dimension

def __repr__(self):
return self.__class__.__name__

def forward(self, input, input_glob):
assert isinstance(input, SparseTensor)
assert isinstance(input_glob, SparseTensor)
assert input.D == self.dimension

broadcast_feat = torch.empty_like(input.F)
row_inds = input.coords_man.get_row_indices_per_batch(input.coords_key)
@@ -261,7 +225,6 @@ class MinkowskiBroadcastConcatenation(MinkowskiBroadcast):
def forward(self, input, input_glob):
assert isinstance(input, SparseTensor)
assert isinstance(input_glob, SparseTensor)
assert input.D == self.dimension

broadcast_feat = torch.empty_like(input.F)
row_inds = input.coords_man.get_row_indices_per_batch(input.coords_key)
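Independent of the MinkowskiEngine API, the broadcast layers above copy each batch entry's single global feature row back to every point of that batch and then combine it with the per-point features. A plain-PyTorch sketch of that semantics, with hypothetical helper and tensor names:

```python
import torch

def broadcast_global_to_points(point_feats, global_feats, batch_indices):
    # point_feats:   (N, C) per-point features of the sparse tensor
    # global_feats:  (B, C) one pooled feature row per batch entry
    # batch_indices: (N,)   batch index of each point
    return global_feats[batch_indices]  # row i receives its batch's global feature

point_feats = torch.randn(5, 4)
global_feats = torch.randn(2, 4)
batch_indices = torch.tensor([0, 0, 0, 1, 1])

b = broadcast_global_to_points(point_feats, global_feats, batch_indices)
added = point_feats + b                        # MinkowskiBroadcastAddition
multiplied = point_feats * b                   # MinkowskiBroadcastMultiplication
concatenated = torch.cat([point_feats, b], 1)  # MinkowskiBroadcastConcatenation
```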
10 changes: 5 additions & 5 deletions MinkowskiEngine/MinkowskiNormalization.py
@@ -264,11 +264,11 @@ def __init__(self, num_features, dimension=-1):
self.weight = nn.Parameter(torch.ones(1, num_features))
self.bias = nn.Parameter(torch.zeros(1, num_features))

self.mean_in = MinkowskiGlobalPooling(dimension=dimension)
self.glob_sum = MinkowskiBroadcastAddition(dimension=dimension)
self.glob_sum2 = MinkowskiBroadcastAddition(dimension=dimension)
self.glob_mean = MinkowskiGlobalPooling(dimension=dimension)
self.glob_times = MinkowskiBroadcastMultiplication(dimension=dimension)
self.mean_in = MinkowskiGlobalPooling()
self.glob_sum = MinkowskiBroadcastAddition()
self.glob_sum2 = MinkowskiBroadcastAddition()
self.glob_mean = MinkowskiGlobalPooling()
self.glob_times = MinkowskiBroadcastMultiplication()
self.dimension = dimension
self.reset_parameters()

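The submodules updated above (global pooling plus broadcast addition/multiplication) appear to be the building blocks of an instance-norm style computation in this file. A dense-feature sketch of the usual recipe, not the library implementation itself, with an assumed `eps`:

```python
import torch

def instance_norm_sketch(feats, batch_indices, weight, bias, eps=1e-8):
    # feats: (N, C) per-point features; batch_indices: (N,) batch index per point
    out = torch.empty_like(feats)
    for b in batch_indices.unique():
        mask = batch_indices == b
        x = feats[mask]
        mean = x.mean(dim=0, keepdim=True)             # global average pooling
        centered = x - mean                            # broadcast addition of -mean
        var = centered.pow(2).mean(dim=0, keepdim=True)
        out[mask] = centered / (var + eps).sqrt()      # broadcast multiplication
    return out * weight + bias                         # learnable affine, shape (1, C)
```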
41 changes: 6 additions & 35 deletions MinkowskiEngine/MinkowskiPooling.py
@@ -726,7 +726,7 @@ class MinkowskiGlobalPooling(MinkowskiModuleBase):
"""

def __init__(self, average=True, mode=GlobalPoolingMode.AUTO, dimension=-1):
def __init__(self, average=True, mode=GlobalPoolingMode.AUTO):
r"""Reduces sparse coords into points at origin, i.e. reduce each point
cloud into a point at the origin, returning batch_size number of points
[[0, 0, ..., 0], [0, 0, ..., 1], [0, 0, ..., 2]] where the last elem
@@ -736,24 +736,17 @@ def __init__(self, average=True, mode=GlobalPoolingMode.AUTO, dimension=-1):
:attr:`average` (bool): when True, return the averaged output. If
not, return the sum of all input features.
:attr:`dimension` (int): the spatial dimension of the space where
all the inputs and the network are defined. For example, images are
in a 2D space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiGlobalPooling, self).__init__()
assert dimension > 0, f"dimension must be a positive integer, {dimension}"
assert isinstance(mode, GlobalPoolingMode), \
f"Mode must be an instance of GlobalPoolingMode. mode={mode}"

self.mode = mode
self.average = average
self.dimension = dimension
self.pooling = MinkowskiGlobalPoolingFunction()

def forward(self, input):
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

out_coords_key = CoordsKey(input.coords_key.D)
output = self.pooling.apply(input.F, self.average, self.mode,
@@ -769,38 +762,26 @@ def __repr__(self):

class MinkowskiGlobalSumPooling(MinkowskiGlobalPooling):

def __init__(self, mode=GlobalPoolingMode.AUTO, dimension=-1):
def __init__(self, mode=GlobalPoolingMode.AUTO):
r"""Reduces sparse coords into points at origin, i.e. reduce each point
cloud into a point at the origin, returning batch_size number of points
[[0, 0, ..., 0], [0, 0, ..., 1], [0, 0, ..., 2]] where the last elem
of the coords is the batch index.
Args:
:attr:`dimension` (int): the spatial dimension of the space where
all the inputs and the network are defined. For example, images are
in a 2D space, meshes and 3D shapes are in a 3D space.
"""
MinkowskiGlobalPooling.__init__(
self, False, mode=mode, dimension=dimension)
MinkowskiGlobalPooling.__init__(self, False, mode=mode)


class MinkowskiGlobalAvgPooling(MinkowskiGlobalPooling):

def __init__(self, mode=GlobalPoolingMode.AUTO, dimension=-1):
def __init__(self, mode=GlobalPoolingMode.AUTO):
r"""Reduces sparse coords into points at origin, i.e. reduce each point
cloud into a point at the origin, returning batch_size number of points
[[0, 0, ..., 0], [0, 0, ..., 1], [0, 0, ..., 2]] where the last elem
of the coords is the batch index.
Args:
:attr:`dimension` (int): the spatial dimension of the space where
all the inputs and the network are defined. For example, images are
in a 2D space, meshes and 3D shapes are in a 3D space.
"""
MinkowskiGlobalPooling.__init__(
self, True, mode=mode, dimension=dimension)
MinkowskiGlobalPooling.__init__(self, True, mode=mode)


class MinkowskiGlobalMaxPoolingFunction(Function):
@@ -852,28 +833,18 @@ class MinkowskiGlobalMaxPooling(MinkowskiModuleBase):
"""

def __init__(self, dimension=-1):
def __init__(self):
r"""Reduces sparse coords into points at origin, i.e. reduce each point
cloud into a point at the origin, returning batch_size number of points
[[0, 0, ..., 0], [0, 0, ..., 1], [0, 0, ..., 2]] where the last elem
of the coords is the batch index.
Args:
:attr:`dimension` (int): the spatial dimension of the space where
all the inputs and the network are defined. For example, images are
in a 2D space, meshes and 3D shapes are in a 3D space.
"""
super(MinkowskiGlobalMaxPooling, self).__init__()
assert dimension > 0, f"dimension must be a positive integer, {dimension}"

self.dimension = dimension
self.pooling = MinkowskiGlobalMaxPoolingFunction()

def forward(self, input):
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

out_coords_key = CoordsKey(input.coords_key.D)
output = self.pooling.apply(input.F, input.coords_key, out_coords_key,
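For reference, the reduction performed by the global pooling layers above can be written in plain PyTorch over per-point features and batch indices (hypothetical helper, not the library code):

```python
import torch

def global_pool_sketch(point_feats, batch_indices, average=True):
    # Reduce (N, C) point features to (B, C): one row per batch entry.
    num_batches = int(batch_indices.max().item()) + 1
    pooled = torch.zeros(num_batches, point_feats.shape[1], dtype=point_feats.dtype)
    pooled.index_add_(0, batch_indices, point_feats)  # per-batch sum
    if average:
        counts = torch.bincount(batch_indices, minlength=num_batches).clamp(min=1)
        pooled = pooled / counts.unsqueeze(1).to(pooled.dtype)
    return pooled
```

MinkowskiGlobalMaxPooling takes a per-batch maximum instead of a sum or average.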
4 changes: 2 additions & 2 deletions MinkowskiEngine/modules/senet_block.py
@@ -38,8 +38,8 @@ def __init__(self, channel, reduction=16, D=-1):
ME.MinkowskiReLU(inplace=True),
ME.MinkowskiLinear(channel // reduction, channel),
ME.MinkowskiSigmoid())
self.pooling = ME.MinkowskiGlobalPooling(dimension=D)
self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D)
self.pooling = ME.MinkowskiGlobalPooling()
self.broadcast_mul = ME.MinkowskiBroadcastMultiplication()

def forward(self, x):
y = self.pooling(x)
2 changes: 1 addition & 1 deletion README.md
@@ -165,7 +165,7 @@ class ExampleNetwork(ME.MinkowskiNetwork):
dimension=D),
ME.MinkowskiBatchNorm(128),
ME.MinkowskiReLU())
self.pooling = ME.MinkowskiGlobalPooling(dimension=D)
self.pooling = ME.MinkowskiGlobalPooling()
self.linear = ME.MinkowskiLinear(128, out_feat)

def forward(self, x):
2 changes: 1 addition & 1 deletion docs/demo/interop.rst
@@ -40,7 +40,7 @@ a mini-batch.
kernel_size=3,
stride=2,
dimension=D), ME.MinkowskiBatchNorm(128), ME.MinkowskiReLU(),
ME.MinkowskiGlobalPooling(dimension=D),
ME.MinkowskiGlobalPooling(),
ME.MinkowskiLinear(128, out_feat))
def forward(self, x):
2 changes: 1 addition & 1 deletion docs/overview.md
@@ -166,7 +166,7 @@ class ExampleNetwork(ME.MinkowskiNetwork):
dimension=D),
ME.MinkowskiBatchNorm(128),
ME.MinkowskiReLU())
self.pooling = ME.MinkowskiGlobalPooling(dimension=D)
self.pooling = ME.MinkowskiGlobalPooling()
self.linear = ME.MinkowskiLinear(128, out_feat)

def forward(self, x):
2 changes: 1 addition & 1 deletion examples/example.py
@@ -49,7 +49,7 @@ def __init__(self, in_feat, out_feat, D):
kernel_size=3,
stride=2,
dimension=D), ME.MinkowskiBatchNorm(128), ME.MinkowskiReLU(),
ME.MinkowskiGlobalPooling(dimension=D),
ME.MinkowskiGlobalPooling(),
ME.MinkowskiLinear(128, out_feat))

def forward(self, x):
40 changes: 25 additions & 15 deletions examples/multigpu.py
@@ -29,8 +29,7 @@
try:
import open3d as o3d
except ImportError:
raise ImportError(
'Please install open3d-python with `pip install open3d`.')
raise ImportError('Please install open3d-python with `pip install open3d`.')

import torch
import torch.nn as nn
@@ -41,17 +40,14 @@

import torch.nn.parallel as parallel


if not os.path.isfile('weights.pth'):
urlretrieve("http://cvgl.stanford.edu/data2/minkowskiengine/1.ply", '1.ply')


parser = argparse.ArgumentParser()
parser.add_argument('--file_name', type=str, default='1.ply')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--max_ngpu', type=int, default=2)


cache = {}


@@ -71,13 +67,25 @@ def load_file(file_name, voxel_size):
return quantized_coords[inds], feats[inds], random_labels


def generate_input(file_name, voxel_size):
# Create a batch, this process is done in a data loader during training in parallel.
batch = [load_file(file_name, voxel_size)]
coordinates_, features_, labels_ = list(zip(*batch))
coordinates, features, labels = ME.utils.sparse_collate(
coordinates_, features_, labels_)

# Normalize features and create a sparse tensor
return coordinates, (features - 0.5).float(), labels


if __name__ == '__main__':
# loss and network
config = parser.parse_args()
num_devices = torch.cuda.device_count()
num_devices = min(config.max_ngpu, num_devices)
devices = list(range(num_devices))
print('Testing ', num_devices, ' GPUs. Total batch size: ', num_devices * config.batch_size)
print('Testing ', num_devices, ' GPUs. Total batch size: ',
num_devices * config.batch_size)

# For copying the final loss back to one GPU
target_device = devices[0]
Expand All @@ -87,7 +95,7 @@ def load_file(file_name, voxel_size):
net = net.to(target_device)

# Synchronized batch norm
net = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(net);
net = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(net)
optimizer = SGD(net.parameters(), lr=1e-1)

# Copy the loss layer
@@ -99,14 +107,15 @@ def load_file(file_name, voxel_size):
optimizer.zero_grad()

# Get new data
inputs, labels = [], []
inputs, all_labels = [], []
for i in range(num_devices):
batch = [load_file(config.file_name, 0.05) for _ in range(config.batch_size)]
coordinates_, featrues_, random_labels = list(zip(*batch))
coordinates, features = ME.utils.sparse_collate(coordinates_, featrues_)
coordinates, features, labels = generate_input(
config.file_name, voxel_size=0.05)
with torch.cuda.device(devices[i]):
inputs.append(ME.SparseTensor(features - 0.5, coords=coordinates).to(devices[i]))
labels.append(torch.cat(random_labels).long().to(devices[i]))
inputs.append(
ME.SparseTensor(features, coords=coordinates).to(devices[i]))
all_labels.append(labels.long().to(devices[i]))

# The raw version of the parallel_apply
st = time()
@@ -116,11 +125,12 @@ def load_file(file_name, voxel_size):
# Extract features from the sparse tensors to use a pytorch criterion
out_features = [output.F for output in outputs]
losses = parallel.parallel_apply(
criterions, tuple(zip(out_features, labels)), devices=devices)
criterions, tuple(zip(out_features, all_labels)), devices=devices)
loss = parallel.gather(losses, target_device, dim=0).mean()
t = time() - st
min_time = min(t, min_time)
print('Iteration: ', iteration, ', Loss: ', loss.item(), ', Time: ', t, ', Min time: ', min_time)
print('Iteration: ', iteration, ', Loss: ', loss.item(), ', Time: ', t,
', Min time: ', min_time)

# Gradient
loss.backward()
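The refactored loop above follows the usual manual data-parallel pattern: build one input per device, run the replicas with `parallel_apply`, and gather the per-device losses onto the target device. A condensed sketch with hypothetical names, omitting the Minkowski-specific tensor construction:

```python
import torch.nn.parallel as parallel

def parallel_mean_loss(replicas, criterions, inputs, targets, devices, target_device):
    # One replica, criterion, input, and target per device.
    outputs = parallel.parallel_apply(replicas, inputs, devices=devices)
    losses = parallel.parallel_apply(
        criterions, tuple(zip(outputs, targets)), devices=devices)
    return parallel.gather(losses, target_device, dim=0).mean()
```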