
Commit

sync once
zhanghang1989 committed Apr 20, 2018
1 parent c6dc617 commit 4dcec47
Showing 6 changed files with 123 additions and 295 deletions.
10 changes: 9 additions & 1 deletion docs/source/notes/syncbn.rst
@@ -8,6 +8,9 @@ How BN works?

BN layer was introduced in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_, which dramatically speeds up the training of the network (enabling larger learning rates) and makes the network less sensitive to weight initialization.

.. image:: http://hangzh.com/blog/images/bn1.png
:align: center

- Forward Pass:
For the input data :math:`X=\{x_1, ..., x_N\}`, the data are normalized to be zero-mean and unit-variance, then scaled and shifted (see the sketch below):

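A minimal PyTorch sketch of this forward pass (the variable names below are ours, not the library's kernel):

.. code-block:: python

    import torch

    # toy input: N samples, C channels
    x = torch.randn(8, 4)                        # X = {x_1, ..., x_N}
    gamma, beta = torch.ones(4), torch.zeros(4)  # learnable scale and shift
    eps = 1e-5

    mean = x.mean(dim=0)                         # per-channel mean
    var = x.var(dim=0, unbiased=False)           # per-channel (biased) variance
    x_hat = (x - mean) / torch.sqrt(var + eps)   # zero-mean, unit-variance
    y = gamma * x_hat + beta                     # scale and shift
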
@@ -31,6 +34,9 @@ Why Synchronize BN?

- Standard implementations of BN in public frameworks (such as Caffe, MXNet, Torch, TF, PyTorch) are unsynchronized, which means that the data are normalized within each GPU. Therefore the `working batch-size` of the BN layer is `BatchSize/nGPU` (the batch-size in each GPU), as illustrated by the sketch after the figure below.

.. image:: http://hangzh.com/blog/images/bn2.png
:align: center
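
The effect can be sketched on a single device by splitting a batch into per-GPU chunks; each chunk then gets its own statistics, which generally differ from the global ones (the chunking below only emulates the per-GPU split, it is not the frameworks' actual code):

.. code-block:: python

    import torch

    x = torch.randn(16, 4)        # global batch of 16, 4 channels
    chunks = x.chunk(4, dim=0)    # emulate 4 GPUs: working batch-size = 16/4 = 4

    global_mean = x.mean(dim=0)
    per_gpu_mean = [c.mean(dim=0) for c in chunks]
    # each per-GPU mean is estimated from only 4 samples and in general
    # differs from global_mean; unsynchronized BN normalizes with the former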

- Since the `working batch-size` is typically large enough for standard vision tasks, such as classification and detection, there is no need to synchronize the BN layer during training; the synchronization would only slow the training down.

- However, for the semantic segmentation task, state-of-the-art approaches typically adopt dilated convolution, which is very memory-consuming. The `working batch-size` can be too small for BN layers (2 or 4 in each GPU) when using larger/deeper pre-trained networks, such as :class:`encoding.dilated.ResNet` or :class:`encoding.dilated.DenseNet`.
@@ -47,8 +53,10 @@ Suppose we have :math:`K` number of GPUs, :math:`sum(x)_k` and :math:`sum(x^2)_k
* :math:`\frac{d_\ell}{d_{x_i}}=\frac{d_\ell}{d_{y_i}}\frac{\gamma}{\sigma}` can be calculated locally in each GPU.
* Calculate the gradients with respect to :math:`sum(x)` and :math:`sum(x^2)` individually in each GPU, i.e. :math:`\frac{d_\ell}{d_{sum(x)_k}}` and :math:`\frac{d_\ell}{d_{sum(x^2)_k}}`.

* Then sync the gradients (automatically handled by :class:`encoding.parallel.allreduce`) and continue the backward pass.
* Then sync the gradients (automatically handled by :class:`encoding.parallel.AllReduce`) and continue the backward pass, as shown in the sketch after the figure below.

.. image:: http://hangzh.com/blog/images/bn3.png
:align: center
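
A rough sketch of this statistics gathering, with a plain Python ``sum`` standing in for the cross-GPU allreduce (the real implementation reduces :math:`sum(x)_k` and :math:`sum(x^2)_k` across devices in the forward pass and reduces their gradients the same way in the backward pass):

.. code-block:: python

    import torch

    eps = 1e-5
    xs = [torch.randn(4, 3) for _ in range(2)]   # per-device mini-batches, C=3

    # per-device partial statistics: sum(x)_k and sum(x^2)_k
    sums    = [x.sum(dim=0) for x in xs]
    squares = [(x * x).sum(dim=0) for x in xs]

    # cross-device reduction (an allreduce in the real implementation)
    xsum, xsquare = sum(sums), sum(squares)
    N = sum(x.size(0) for x in xs)

    mean = xsum / N
    var = xsquare / N - mean * mean              # Var(x) = E(x^2) - E(x)^2
    std = (var + eps).sqrt()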

Citation
--------
4 changes: 1 addition & 3 deletions encoding/functions/syncbn.py
@@ -55,9 +55,7 @@ def backward(ctx, gradSum, gradSquare):


def sum_square(input):
r"""
Calculate sum of elements and sum of squares for Batch Normalization.
"""
r"""Calculate sum of elements and sum of squares for Batch Normalization"""
return _sum_square.apply(input)
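# Illustrative call pattern (our reading of the forward pass in
# encoding/nn/syncbn.py, not a documented contract); the underlying kernel
# runs on the GPU, so the input is expected to be a CUDA tensor:
#
#   x = torch.randn(8, 64, 32, 32).cuda()   # (B, C, H, W)
#   xsum, xsquare = sum_square(x)           # per-channel sum(x) and sum(x^2)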


245 changes: 55 additions & 190 deletions encoding/nn/syncbn.py
@@ -13,136 +13,66 @@
import torch
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter

from ..functions import batchnormtrain, batchnormeval, sum_square
from ..parallel import allreduce

# import standard layers for convenient use
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'Module', 'Sequential', 'Conv1d',
'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d',
'AvgPool2d', 'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
#__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']

class BatchNorm1d(Module):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
Standard BN [1]_ implementation only normalizes the data within each device.
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-once implementation described in the paper [2]_ .
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # Use exactly the same as standard BatchNorm1d
>>> m = nn.BatchNorm1d(100)
>>> output = m(input)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(BatchNorm1d, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
class _SyncBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, **kwargs):
super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, **kwargs)
# syncBN
self.writelock = threading.Lock()
nGPUs = torch.cuda.device_count()
self.xsum = SharedTensor(nGPUs)
self.xsquare = SharedTensor(nGPUs)

def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()

def __repr__(self):
return ('{name}({num_features}, eps={eps}, momentum={momentum},'
' affine={affine})'
.format(name=self.__class__.__name__, **self.__dict__))

def _check_input_dim(self, input):
if input.dim() != 3:
raise ValueError('expected 3D input (got {}D input)'
.format(input.dim()))
self.sharedT = SharedTensor(nGPUs)

def forward(self, input):
self._check_input_dim(input)
if self.training:
# push the value
isum, isquare = sum_square(input.unsqueeze(3))
idxs = self.xsum.push(isum)
idxq = self.xsquare.push(isquare)
xsum = self.xsum[idxs]
xsquare = self.xsquare[idxq]
# calculate N
N = len(self.xsum)*input.size(0)*input.size(2)
mean = xsum / N
sumvar = xsquare - xsum * xsum / N
unbias_var = sumvar / (N - 1)
std = (sumvar / N + self.eps).sqrt()
# update running_mean and var
self.running_mean = (1-self.momentum) * self.running_mean \
+ self.momentum * mean.data
self.running_var = (1-self.momentum) * self.running_var + \
self.momentum * unbias_var.data
# forward
return batchnormtrain(input, self.weight,
self.bias, mean, std)
else:
std = (self.running_var + self.eps).sqrt()
return batchnormeval(input, self.weight, self.bias,
self.running_mean, std)

input_shape = input.size()
input = input.view(input_shape[0], self.num_features, -1)
if not self.training:
std = (self.running_var.clamp(self.eps)).sqrt()
output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
return output.view(input_shape)
# get global sum(x) and sum(x^2)
xsum, xsquare = self.sharedT(sum_square(input.unsqueeze(3)))
# calculate mean, var
N = len(self.sharedT) * input.size(0) * input.size(2)
mean = xsum / N
sumvar = xsquare - xsum * xsum / N
unbias_var = sumvar / (N - 1)
bias_var = sumvar / N
std = bias_var.clamp(self.eps).sqrt()
# update running_mean and var
self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
# forward
return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)


class BatchNorm1d(_SyncBatchNorm):
r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))

class BatchNorm2d(Module):
class BatchNorm2d(_SyncBatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
Standard BN [1]_ implementation only normalizes the data within each device.
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-once implementation described in the paper [2]_ .
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-dimension over
The mean and standard-deviation are calculated per-channel over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
@@ -177,78 +107,20 @@ class BatchNorm2d(Module):
>>> m = nn.BatchNorm2d(100)
>>> output = m(input)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(BatchNorm2d, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
self.writelock = threading.Lock()
nGPUs = torch.cuda.device_count()
self.xsum, self.xsquare = SharedTensor(nGPUs), SharedTensor(nGPUs)

def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()

def __repr__(self):
return ('{name}({num_features}, eps={eps}, momentum={momentum},'
' affine={affine})'
.format(name=self.__class__.__name__, **self.__dict__))

def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))

def forward(self, input):
self._check_input_dim(input)
if self.training:
# push the value
isum, isquare = sum_square(input)
idxs = self.xsum.push(isum)
idxq = self.xsquare.push(isquare)
xsum = self.xsum[idxs]
xsquare = self.xsquare[idxq]
# calculate N
N = len(self.xsum)*input.size(0)*input.size(2)*input.size(3)
mean = xsum / N
sumvar = xsquare - xsum * xsum / N
unbias_var = sumvar / (N - 1)
std = (sumvar / N + self.eps).sqrt()
# update running_mean and var
self.running_mean = (1-self.momentum) * self.running_mean \
+ self.momentum * mean.data
self.running_var = (1-self.momentum) * self.running_var + \
self.momentum * unbias_var.data
# forward
B, C, H, W = input.size()
output = batchnormtrain(
input.view(B, C, -1).contiguous(), self.weight,
self.bias, mean,
std)
return output.view(B, C, H, W)
else:
std = (self.running_var + self.eps).sqrt()
B, C, H, W = input.size()
return batchnormeval(input.view(B, C, -1).contiguous(), self.weight, self.bias,
self.running_mean, std).view(B, C, H, W)

class BatchNorm3d(_SyncBatchNorm):
r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))

class SharedTensor(object):
"""Shared Tensor
"""Shared Tensor for cross GPU communication
"""
def __init__(self, nGPUs):
self.mutex = threading.Lock()
@@ -261,44 +133,37 @@ def _clear(self):
self.push_tasks = self.nGPUs
self.reduce_tasks = self.nGPUs

def push(self, t):
"""push a Tensor
"""
def __call__(self, *inputs):
# push from device
with self.mutex:
if self.push_tasks == 0:
self._clear()
self.list.append(t)
idx = len(self.list) - 1
self.list.extend(list(*inputs))
idx = self.nGPUs - self.push_tasks
self.push_tasks -= 1

with self.all_tasks_done:
if self.push_tasks == 0:
self.all_tasks_done.notify_all()
while self.push_tasks:
self.all_tasks_done.wait()
return idx

def _reduce(self):
# pull from device
with self.mutex:
if self.reduce_tasks == self.nGPUs:
assert(len(self.list) == self.nGPUs)
self.outlist = allreduce(*self.list)
assert(len(self.list) == 2 * self.nGPUs)
self.list = allreduce(2, *self.list)
self.reduce_tasks -= 1
else:
self.reduce_tasks -= 1

with self.all_tasks_done:
if self.reduce_tasks == 0:
self.all_tasks_done.notify_all()
while self.reduce_tasks:
self.all_tasks_done.wait()

def __getitem__(self, idx):
self._reduce()
return self.outlist[idx]
# all reduce done
return self.list[2*idx], self.list[2*idx+1]

def __len__(self):
return len(self.list)
return self.nGPUs

def __repr__(self):
return ('SharedTensor')
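
# Rough sketch of the calling pattern (inferred from _SyncBatchNorm.forward
# above; simplified, not the library's test code): one SharedTensor is shared
# by all replicas of a SyncBN layer, and each replica thread calls it with its
# per-device partial statistics:
#
#   shared = SharedTensor(nGPUs)
#   # inside forward(), executed by every replica thread in parallel:
#   xsum, xsquare = shared(sum_square(input.unsqueeze(3)))
#
# the call blocks until all nGPUs replicas have pushed, performs a single
# allreduce over the 2 * nGPUs partial sums, and returns the globally reduced
# pair (the copy living on the calling device).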