synchronized batch norm, minor fixes on multi-gpu
chrischoy committed Dec 16, 2019
1 parent fe836bd commit 8d86d81
Showing 10 changed files with 280 additions and 139 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -1,7 +1,18 @@
# Change Log


## [nightly] - 2019-12-15

- Synchronized Batch Norm: `ME.MinkowskiSyncBatchNorm`
- `ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm` converts a Minkowski network automatically to use synchronized batch norm.
- `examples/multigpu.py` updated for `ME.MinkowskiSyncBatchNorm`.
- Update multigpu documentation
- Update GIL release
- Minor error fixes on `examples/modelnet40.py`


## [0.3.1] - 2019-12-15

- Cache in-out mapping on device
- Robinhood unordered map for coordinate management
- Moved hash-based quantization to C++ `CoordsManager`-based quantization with label collision
75 changes: 75 additions & 0 deletions MinkowskiEngine/MinkowskiNormalization.py
@@ -68,6 +68,81 @@ def __repr__(self):
return self.__class__.__name__ + s


class MinkowskiSyncBatchNorm(MinkowskiBatchNorm):
r"""A batch normalization layer with multi GPU synchronization.
"""

def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
process_group=None):
Module.__init__(self)
self.bn = torch.nn.SyncBatchNorm(
num_features,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
process_group=process_group)

def forward(self, input):
# torch.nn.SyncBatchNorm requires inputs with more than 2 dimensions, so add a
# dummy dimension to the N x C feature matrix and remove it after normalization.
output = self.bn(input.F.unsqueeze(2)).squeeze(2)
return SparseTensor(
output,
coords_key=input.coords_key,
coords_manager=input.coords_man)

@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
r"""Helper function to convert `ME.MinkowskiBatchNorm` layer in the model to
`ME.MinkowskiSyncBatchNorm` layer.
Args:
module (nn.Module): containing module
process_group (optional): process group to scope synchronization,
default is the whole world
Returns:
The original module with the converted `ME.MinkowskiSyncBatchNorm` layer
Example::
>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>> torch.nn.Linear(20, 100),
>>> torch.nn.BatchNorm1d(100)
>>> ).cuda()
>>> # creating process group (optional)
>>> # process_ids is a list of int identifying rank ids.
>>> process_group = torch.distributed.new_group(process_ids)
>>> sync_bn_module = convert_sync_batchnorm(module, process_group)
"""
module_output = module
if isinstance(module, MinkowskiBatchNorm):
module_output = MinkowskiSyncBatchNorm(
module.bn.num_features, module.bn.eps, module.bn.momentum,
module.bn.affine, module.bn.track_running_stats, process_group)
if module.bn.affine:
module_output.bn.weight.data = module.bn.weight.data.clone().detach()
module_output.bn.bias.data = module.bn.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.bn.weight.requires_grad = module.bn.weight.requires_grad
module_output.bn.bias.requires_grad = module.bn.bias.requires_grad
module_output.bn.running_mean = module.bn.running_mean
module_output.bn.running_var = module.bn.running_var
module_output.bn.num_batches_tracked = module.bn.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(
name, cls.convert_sync_batchnorm(child, process_group))
del module
return module_output


class MinkowskiInstanceNormFunction(Function):

@staticmethod
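For readers skimming the diff, here is a minimal, self-contained sketch of how the new layer slots into a Minkowski network. The small block and its channel sizes below are illustrative, not part of this commit. Note that `torch.nn.SyncBatchNorm` (and therefore `MinkowskiSyncBatchNorm`) only synchronizes statistics across GPUs when run inside a `torch.distributed` process group, e.g. under `DistributedDataParallel`.

```python
import torch.nn as nn
import MinkowskiEngine as ME

# An illustrative block: a sparse convolution followed by the new synchronized
# batch norm. The layer normalizes the N x C feature matrix of the SparseTensor,
# synchronizing batch statistics across processes when a process group is active.
conv_bn_relu = nn.Sequential(
    ME.MinkowskiConvolution(
        in_channels=3, out_channels=16, kernel_size=3, dimension=3),
    ME.MinkowskiSyncBatchNorm(16),
    ME.MinkowskiReLU())

# Alternatively, convert an existing network in one call:
# net = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(net)
```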
5 changes: 2 additions & 3 deletions MinkowskiEngine/__init__.py
@@ -57,9 +57,8 @@
MinkowskiPReLU, MinkowskiELU, MinkowskiSELU, MinkowskiCELU, MinkowskiDropout, \
MinkowskiThreshold, MinkowskiTanh


from MinkowskiNormalization import MinkowskiBatchNorm, MinkowskiInstanceNorm, \
MinkowskiInstanceNormFunction, MinkowskiStableInstanceNorm
from MinkowskiNormalization import MinkowskiBatchNorm, MinkowskiSyncBatchNorm, \
MinkowskiInstanceNorm, MinkowskiInstanceNormFunction, MinkowskiStableInstanceNorm

from MinkowskiPruning import MinkowskiPruning, MinkowskiPruningFunction

51 changes: 26 additions & 25 deletions docs/demo/multigpu.md
@@ -6,34 +6,22 @@ Currently, the MinkowskiEngine supports Multi-GPU training through data parallel
Let's define a network first.

```python
import torch.nn as nn
import MinkowskiEngine as ME
from examples.minkunet import MinkUNet34C

# Copy the network to GPU
net = MinkUNet34C(3, 20, D=3)
net = net.to(target_device)
```

Synchronized Batch Norm
-----------------------

Next, we create a new network with `ME.MinkowskiSyncBatchNorm` that replaces all `ME.MinkowskiBatchNorm` layers. This allows the network to use a large batch size and to maintain the same performance as single-GPU training.

```python
# Synchronized batch norm
net = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(net)
```

This commit also removes the inline `ExampleNetwork` definition that this page previously used, in favor of `MinkUNet34C`:

```python
class ExampleNetwork(ME.MinkowskiNetwork):

    def __init__(self, in_feat, out_feat, D):
        super(ExampleNetwork, self).__init__(D)
        self.net = nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels=in_feat,
                out_channels=64,
                kernel_size=3,
                stride=2,
                dilation=1,
                has_bias=False,
                dimension=D), ME.MinkowskiBatchNorm(64), ME.MinkowskiReLU(),
            ME.MinkowskiConvolution(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                stride=2,
                dimension=D), ME.MinkowskiBatchNorm(128), ME.MinkowskiReLU(),
            ME.MinkowskiGlobalPooling(dimension=D),
            ME.MinkowskiLinear(128, out_feat))

    def forward(self, x):
        return self.net(x)
```

Next, we need to create replicas of the network and the final loss layer (if you use one).
@@ -90,3 +78,16 @@ loss = parallel.gather(losses, target_device, dim=0).mean()
```

The rest of the training, such as the backward pass and taking a step with the optimizer, is the same as in single-GPU training. Please refer to the [complete multi-gpu example](https://github.com/StanfordVL/MinkowskiEngine/blob/master/examples/multigpu.py) for more detail.
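For orientation, the sketch below strings the pieces above into one data-parallel training iteration. It is a simplified illustration, not a copy of `examples/multigpu.py`: the variable names `devices`, `coords_list`, `feats_list`, and `labels_list` are assumed to come from your data loader, `criterion` and `optimizer` are assumed to be set up already, and the `SparseTensor` construction may differ slightly from the shipped example.

```python
import torch.nn.parallel as parallel
import MinkowskiEngine as ME

# Assumed setup (illustrative): `net` is a Minkowski network already converted
# with convert_sync_batchnorm and placed on `target_device`, `criterion` is a
# loss module such as torch.nn.CrossEntropyLoss(), `optimizer` is its optimizer,
# and `devices` is the list of CUDA devices with devices[0] == target_device.

# Replicate the network and the loss layer onto every GPU.
replicas = parallel.replicate(net, devices)
criterions = parallel.replicate(criterion, devices)

# One SparseTensor per GPU, built from per-device coordinate/feature chunks
# (coords_list, feats_list, labels_list come from the data loader).
inputs = [
    ME.SparseTensor(feats.to(device), coords=coords)
    for coords, feats, device in zip(coords_list, feats_list, devices)
]
targets = [labels.to(device) for labels, device in zip(labels_list, devices)]

# Forward pass on all GPUs in parallel, then the per-GPU losses.
outputs = parallel.parallel_apply(replicas, inputs, devices=devices)
losses = parallel.parallel_apply(
    criterions,
    list(zip([output.F for output in outputs], targets)),
    devices=devices)

# Gather the losses on the target device, average, and update as usual.
loss = parallel.gather(losses, target_device, dim=0).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```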


Speed up
--------

We use a total batch size of 8 on 4 Titan Xp GPUs for this experiment and divide the load equally among the GPUs. For instance, with 1 GPU, the single GPU processes batch size 8; with 2 GPUs, each GPU processes batch size 4; with 4 GPUs, each GPU processes batch size 2.


| Number of GPUs | Batch size per GPU | Time per iteration | Speedup |
|:--------------:|:------------------:|:------------------:|:-------:|
| 1 GPU | 8 | 1.611 s | x1 |
| 2 GPUs | 4 | 0.916 s | x1.76 |
| 4 GPUs | 2 | 0.689 s | x2.34 |
12 changes: 12 additions & 0 deletions docs/normalization.rst
@@ -13,6 +13,18 @@ MinkowskiBatchNorm
.. automethod:: __init__


MinkowskiSyncBatchNorm
----------------------

.. autoclass:: MinkowskiEngine.MinkowskiSyncBatchNorm
:members:
:undoc-members:
:exclude-members: forward

.. automethod:: __init__



MinkowskiInstanceNorm
---------------------

31 changes: 15 additions & 16 deletions examples/modelnet40.py
@@ -215,9 +215,8 @@ def __call__(self, coords, feats):
axis = self.axis
else:
axis = np.random.rand(3) - 0.5
R = self._M(
axis,
(np.pi * self.max_theta / 180) * 2 * (np.random.rand(1) - 0.5))
R = self._M(axis, (np.pi * self.max_theta / 180) * 2 *
(np.random.rand(1) - 0.5))
return coords @ R, feats


@@ -359,7 +358,7 @@ def make_data_loader(phase, augment_data, batch_size, shuffle, num_workers,
return loader


def test(net, test_iter):
def test(net, test_iter, phase='val'):
net.eval()
num_correct, tot_num = 0, 0
for i in range(len(test_iter)):
@@ -373,10 +372,9 @@ def test(net, test_iter):

if i % config.stat_freq == 0:
logging.info(
f'{test_iter.dataset.phase} set iter: {i} / {len(test_iter)}, Accuracy : {num_correct / tot_num:.3e}'
f'{phase} set iter: {i} / {len(test_iter)}, Accuracy : {num_correct / tot_num:.3e}'
)
logging.info(
f'{test_iter.dataset.phase} set accuracy : {num_correct / tot_num:.3e}')
logging.info(f'{phase} set accuracy : {num_correct / tot_num:.3e}')


def train(net, device, config):
@@ -436,20 +434,21 @@ def train(net, device, config):

if i % config.stat_freq == 0:
logging.info(
f'Iter: {i}, Loss: {loss.item():.3e}, Data Time: {d:.3e}, Tot Time: {t:.3e}'
f'Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
)

if i % config.val_freq == 0 and i > 0:
torch.save({
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'curr_iter': i,
}, config.weights)
torch.save(
{
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'curr_iter': i,
}, config.weights)

# Validation
logging.info('Validation')
test(net, val_iter)
test(net, val_iter, 'val')

scheduler.step()
logging.info(f'LR: {scheduler.get_lr()}')
@@ -476,4 +475,4 @@ def train(net, device, config):
config=config)

logging.info('Test')
test(net, iter(test_dataloader))
test(net, iter(test_dataloader), 'test')