Commit 2e3e521: pylint

zhanghang1989 committed Apr 20, 2018
1 parent 4dcec47 commit 2e3e521
Showing 4 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -10,7 +10,7 @@ Created by `Hang Zhang <http://hangzh.com/>`_
An optimized PyTorch package with CUDA backend.

.. note::
-   PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.BatchNorm2d` has been released.
+   Please checkout the PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.BatchNorm2d` and the `mnist example <https://github.com/zhanghang1989/PyTorch-SyncBatchNorm>`_.

.. toctree::
:glob:
2 changes: 1 addition & 1 deletion docs/source/notes/syncbn.rst
@@ -53,7 +53,7 @@ Suppose we have :math:`K` number of GPUs, :math:`sum(x)_k` and :math:`sum(x^2)_k
* :math:`\frac{d_\ell}{d_{x_i}}=\frac{d_\ell}{d_{y_i}}\frac{\gamma}{\sigma}` can be calculated locally in each GPU.
* Calculate the gradient of :math:`sum(x)` and :math:`sum(x^2)` individually in each GPU :math:`\frac{d_\ell}{d_{sum(x)_k}}` and :math:`\frac{d_\ell}{d_{sum(x^2)_k}}`.

-* Then Sync the gradient (automatically handled by :class:`encoding.parallel.AllReduce`) and continue the backward.
+* Then Sync the gradient (automatically handled by :class:`encoding.parallel.allreduce`) and continue the backward.

.. image:: http://hangzh.com/blog/images/bn3.png
:align: center
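
For intuition, the reduction described in this hunk can be sketched in a single process, with plain tensors standing in for per-GPU shards (names and shapes here are illustrative only, not the package's API):

import torch

# K "GPUs", each holding a shard of the batch for one channel.
K = 4
shards = [torch.randn(8, requires_grad=True) for _ in range(K)]

# Each device computes its local sum(x) and sum(x^2); summing them
# across devices is the job of the allreduce step in the text above.
xsum = sum(s.sum() for s in shards)          # global sum(x)
xsqsum = sum((s * s).sum() for s in shards)  # global sum(x^2)
N = sum(s.numel() for s in shards)

mean = xsum / N
var = xsqsum / N - mean * mean  # E[x^2] - E[x]^2

# Backward: autograd pushes d(loss)/d(sum(x)) and d(loss)/d(sum(x^2))
# back into every shard; in the multi-GPU setting that cross-device
# accumulation is again an allreduce.
var.backward()
print(shards[0].grad.shape)  # torch.Size([8])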
6 changes: 4 additions & 2 deletions encoding/nn/syncbn.py
@@ -14,14 +14,16 @@
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
    ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter

from ..functions import batchnormtrain, batchnormeval, sum_square
from ..parallel import allreduce

#__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
           'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
           'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']

class _SyncBatchNorm(_BatchNorm):
    # pylint: disable=access-member-before-definition
    def __init__(self, num_features, eps=1e-5, momentum=0.1, **kwargs):
        super(_SyncBatchNorm, self).__init__(num_features, eps=1e-5, momentum=0.1, **kwargs)
        # syncBN
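As a usage note for this module: encoding.nn.BatchNorm2d is intended as a drop-in replacement for torch.nn.BatchNorm2d, with the cross-GPU synchronization supplied by the package's parallel utilities. A minimal sketch, assuming the package is installed (not verified against this exact revision):

import torch
import encoding

# Same constructor interface as torch.nn.BatchNorm2d; synchronization
# across GPUs only takes effect under the package's data-parallel wrappers.
norm = encoding.nn.BatchNorm2d(64)
x = torch.randn(2, 64, 32, 32)
y = norm(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])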
6 changes: 1 addition & 5 deletions encoding/parallel.py
@@ -13,15 +13,12 @@
import torch
from torch.autograd import Function
import torch.cuda.comm as comm
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

-__all__ = ['allreduce', 'ModelDataParallel', 'CriterionDataParallel']
+__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']


def allreduce(num_inputs, *inputs):
@@ -162,4 +159,3 @@ def _worker(i, module, input, target, kwargs, device=None):
            raise output
        outputs.append(output)
    return outputs
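
To illustrate the renamed exports in `__all__` above, here is a hedged sketch of how the model/criterion pair is typically composed (the pattern follows the project's documentation; exact signatures may differ in this revision):

import torch
import torch.nn as nn
import encoding

# Hypothetical two-GPU setup: DataParallelModel returns a list of
# per-GPU outputs, and DataParallelCriterion scatters the targets,
# computes the loss on each device, and reduces the results.
net = nn.Sequential(nn.Conv2d(3, 8, 3), encoding.nn.BatchNorm2d(8)).cuda()
model = encoding.parallel.DataParallelModel(net, device_ids=[0, 1])
criterion = encoding.parallel.DataParallelCriterion(nn.MSELoss(), device_ids=[0, 1])

inputs = torch.randn(4, 3, 32, 32).cuda()
targets = torch.randn(4, 8, 30, 30).cuda()

outputs = model(inputs)             # one output per GPU
loss = criterion(outputs, targets)  # parallel loss, reduced to a scalar
loss.backward()                     # gradients synchronized via allreduce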
