From cebf13414afb9aafbe28ff25fe7fcd256051bf07 Mon Sep 17 00:00:00 2001
From: Hang Zhang <8041160+zhanghang1989@users.noreply.github.com>
Date: Tue, 15 May 2018 21:41:38 -0700
Subject: [PATCH] Adapt SyncBN API from Others' Work (#52)
* update and fix bugs
* adapt syncbn api from other work
* typo
---
build.py | 12 +++-
docs/source/functions.rst | 2 +-
docs/source/index.rst | 2 -
docs/source/nn.rst | 57 +++++++++++++++++
encoding/nn/comm.py | 131 ++++++++++++++++++++++++++++++++++++++
encoding/nn/syncbn.py | 102 ++++++++++++++++++++---------
encoding/parallel.py | 69 +++++++++++++++++++-
7 files changed, 339 insertions(+), 36 deletions(-)
create mode 100644 docs/source/nn.rst
create mode 100644 encoding/nn/comm.py
diff --git a/build.py b/build.py
index 0c5375fa..489e69ef 100644
--- a/build.py
+++ b/build.py
@@ -14,6 +14,8 @@
import subprocess
from torch.utils.ffi import create_extension
+torch_ver = torch.__version__[:3]
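+# PyTorch 0.3 ships versioned ATen libraries (libATen.so.1 / libATen.1.dylib);
+# later releases use the unversioned names, so the build chooses by torch_ver below.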
+
lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
cwd = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'encoding/')
encoding_lib_path = os.path.join(cwd, "lib")
@@ -25,12 +27,18 @@
# build CUDA library
os.environ['TORCH_BUILD_DIR'] = lib_path
if platform.system() == 'Darwin':
- os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
+ if torch_ver == '0.3':
+ os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.1.dylib')
+ else:
+ os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.dylib')
else:
os.environ['CFLAGS'] = '-std=c99'
- os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
+ if torch_ver == '0.3':
+ os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so.1')
+ else:
+ os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.so')
build_all_cmd = ['bash', 'encoding/make.sh']
diff --git a/docs/source/functions.rst b/docs/source/functions.rst
index 3b083dae..144a7a24 100644
--- a/docs/source/functions.rst
+++ b/docs/source/functions.rst
@@ -4,7 +4,7 @@
encoding.functions
==================
-.. automodule:: encoding.Functions
+.. automodule:: encoding.functions
.. currentmodule:: encoding.functions
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9b84e66a..5302df1f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -9,8 +9,6 @@ Created by `Hang Zhang `_
An optimized PyTorch package with CUDA backend.
-.. note::
- PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.SyncBatchNorm2d` and the `MNIST example `_.
.. toctree::
:glob:
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
new file mode 100644
index 00000000..dd0a6450
--- /dev/null
+++ b/docs/source/nn.rst
@@ -0,0 +1,57 @@
+.. role:: hidden
+ :class: hidden-section
+
+encoding.nn
+===========
+
+Customized NN modules in the Encoding package. For Synchronized Cross-GPU Batch Normalization, please visit :class:`encoding.nn.SyncBatchNorm2d`.
+
+.. currentmodule:: encoding.nn
+
+:hidden:`Encoding`
+~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Encoding
+ :members:
+
+:hidden:`SyncBatchNorm2d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SyncBatchNorm2d
+ :members:
+
+:hidden:`SyncBatchNorm1d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SyncBatchNorm1d
+ :members:
+
+:hidden:`SyncBatchNorm3d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SyncBatchNorm3d
+ :members:
+
+:hidden:`Inspiration`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Inspiration
+ :members:
+
+:hidden:`UpsampleConv2d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: UpsampleConv2d
+ :members:
+
+:hidden:`DilatedAvgPool2d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DilatedAvgPool2d
+ :members:
+
+:hidden:`GramMatrix`
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: GramMatrix
+ :members:
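+
+Example of enabling SyncBN under a vanilla :class:`torch.nn.DataParallel`
+(a minimal sketch mirroring the :class:`encoding.nn.BatchNorm2d` docstring)::
+
+    >>> m = encoding.nn.BatchNorm2d(100)
+    >>> net = torch.nn.DataParallel(m)
+    >>> encoding.parallel.patch_replication_callback(net)
+    >>> output = net(input)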
diff --git a/encoding/nn/comm.py b/encoding/nn/comm.py
new file mode 100644
index 00000000..b64bf6ba
--- /dev/null
+++ b/encoding/nn/comm.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
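+ # Send this copy's message to the master, block until the reduced result
+ # comes back, then acknowledge with True so the master can move on.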
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+ - During replication, data parallel triggers a callback on each module, so every slave device should
+ call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
+ are collected and passed to a registered callback.
+ - After receiving the messages, the master device gathers the information and determines the message
+ to be passed back to each slave device.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def register_slave(self, identifier):
+ """
+ Register a slave device.
+
+ Args:
+ identifier: an identifier, usually the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+ The messages are first collected from each device (including the master device), and then
+ a callback is invoked to compute the message to be sent back to each device
+ (including the master device).
+
+ Args:
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SyncBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+ assert results[0][0] == 0, 'The first result should belong to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
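+
+
+if __name__ == '__main__':
+    # Minimal, hypothetical smoke test (not used by the library itself): one
+    # master and one slave thread exchange partial sums through a SyncMaster,
+    # the same handshake that _SyncBatchNorm performs once per forward pass.
+    def _sum_callback(intermediates):
+        # `intermediates` is a list of (identifier, message) pairs; reduce
+        # them and send the same total back to every participant.
+        total = sum(msg for _, msg in intermediates)
+        return [(i, total) for i, _ in intermediates]
+
+    master = SyncMaster(_sum_callback)
+    pipe = master.register_slave(1)  # slaves register during replication
+
+    slave = threading.Thread(target=lambda: print('slave got', pipe.run_slave(2)))
+    slave.start()
+    print('master got', master.run_master(1))  # both sides receive 1 + 2 = 3
+    slave.join()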
diff --git a/encoding/nn/syncbn.py b/encoding/nn/syncbn.py
index 53e6c2b5..7005789a 100644
--- a/encoding/nn/syncbn.py
+++ b/encoding/nn/syncbn.py
@@ -9,7 +9,6 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Synchronized Cross-GPU Batch Normalization Module"""
-import functools
import collections
import threading
import torch
@@ -22,52 +21,96 @@
from ..functions import *
from ..parallel import allreduce
+from .comm import SyncMaster
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
+# Adapted from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SyncBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
- self.sharedT = SharedTensor(torch.cuda.device_count())
def forward(self, input):
+ # If it is not a parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
# Resize the input to (B, C, -1).
input_shape = input.size()
- input = input.view(input_shape[0], self.num_features, -1)
- if not self.training:
- std = (self.running_var.clamp(self.eps)).sqrt()
- output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
- return output.view(input_shape)
+ input = input.view(input.size(0), self.num_features, -1)
# sum(x) and sum(x^2)
N = input.size(0) * input.size(2)
xsum, xsqsum = sum_square(input)
# all-reduce for global sum(x) and sum(x^2)
- igpu = input.get_device()
- self.sharedT.push(N, igpu, xsum, xsqsum)
- N, xsum, xsqsum = self.sharedT.pull(igpu)
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
+ # forward
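+ # batchnormtrain expects a standard deviation, so invert inv_std back to std.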
+ return batchnormtrain(input, self.weight, self.bias, mean, 1.0/inv_std).view(input_shape)
- # calculate mean, var
- mean = xsum / N
- sumvar = xsqsum - xsum * xsum / N
- unbias_var = sumvar / (N - 1)
- bias_var = sumvar / N
- std = bias_var.clamp(self.eps).sqrt()
- # update running_mean and var
- self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
- self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
- # forward
- return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+ # Always using the same "device order" makes the ReduceAdd operation faster.
+ # Thanks to: Tete Xiao (http://tetexiao.com/)
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
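+ # `broadcasted` holds mean and inv_std replicated per device, flattened as
+ # [mean@gpu0, inv_std@gpu0, mean@gpu1, inv_std@gpu1, ...], hence the i*2
+ # slicing below.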
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
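+ # Illustrative numbers: x = [1, 2, 3] gives sum_ = 6, ssum = 14, size = 3,
+ # so mean = 2, sumvar = 14 - 6 * 2 = 2, unbias_var = 1 (matching torch.var)
+ # and bias_var = 2/3.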
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
class BatchNorm1d(_SyncBatchNorm):
@@ -82,13 +125,15 @@ def _check_input_dim(self, input):
class BatchNorm2d(_SyncBatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
- Standard BN [1]_ implementation only normalize the data within each device.
+ Standard BN [1]_ implementation only normalizes the data within each device (GPU).
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-once implementation described in the paper [2]_.
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. note::
- Please use ``CUDA_VISIBLE_DEVICES`` to select number of GPUs.
+ We adapt the awesome Python API from another `PyTorch SyncBN Implementation
+ <https://github.com/vacancy/Synchronized-BatchNorm-PyTorch>`_ and provide an
+ efficient CUDA backend.
.. math::
@@ -125,9 +170,9 @@ class BatchNorm2d(_SyncBatchNorm):
.. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
Examples:
- >>> # Use exactly the same as standard BatchNrom2d
>>> m = BatchNorm2d(100)
>>> net = torch.nn.DataParallel(m)
+ >>> encoding.parallel.patch_replication_callback(net)
>>> output = net(input)
"""
def _check_input_dim(self, input):
@@ -148,11 +193,12 @@ def _check_input_dim(self, input):
class SharedTensor(object):
"""Shared Tensor for cross GPU all reduce operation"""
- def __init__(self, nGPUs):
+ def __init__(self, nGPUs, op):
self.mutex = threading.Lock()
self.all_tasks_done = threading.Condition(self.mutex)
self.nGPUs = nGPUs
self._clear()
+ self.op = op
def _clear(self):
self.N = 0
@@ -160,7 +206,7 @@ def _clear(self):
self.push_tasks = self.nGPUs
self.reduce_tasks = self.nGPUs
- def push(self, *inputs):
+ def __call__(self, *inputs):
if self.nGPUs <= 1:
return tuple(inputs)
# push from device
@@ -177,15 +223,13 @@ def push(self, *inputs):
self.all_tasks_done.notify_all()
while self.push_tasks:
self.all_tasks_done.wait()
-
- def pull(self, igpu):
# pull from device
with self.mutex:
if igpu == 0:
assert(len(self.dict) == self.nGPUs)
# flatten the tensors
self.list = [t for i in range(len(self.dict)) for t in self.dict[i]]
- self.outlist = allreduce(2, *self.list)
+ self.outlist = self.op(2, *self.list)
self.reduce_tasks -= 1
else:
self.reduce_tasks -= 1
diff --git a/encoding/parallel.py b/encoding/parallel.py
index 3b521132..b979ca28 100644
--- a/encoding/parallel.py
+++ b/encoding/parallel.py
@@ -10,6 +10,7 @@
"""Encoding Data Parallel"""
import threading
+import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
@@ -17,10 +18,11 @@
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
-__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']
-
torch_ver = torch.__version__[:3]
+__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
+ 'patch_replication_callback']
+
def allreduce(*inputs):
"""Cross GPU all reduce autograd operation for calculate mean and
variance in SyncBN.
@@ -94,6 +96,11 @@ class DataParallelModel(DataParallel):
def gather(self, outputs, output_device):
return outputs
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelModel, self).replicate(module, device_ids)
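+ # Run the SyncBN replication hooks on the fresh replicas so master/slave
+ # pipes are wired up before the first forward pass.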
+ execute_replication_callbacks(modules)
+ return modules
+
class DataParallelCriterion(DataParallel):
"""
@@ -181,3 +188,61 @@ def _worker(i, module, input, target, kwargs, device=None):
raise output
outputs.append(output)
return outputs
+
+###########################################################################
+# Adapted from Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+ Execute a replication callback `__data_parallel_replicate__` on each module created
+ by the original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Note that, as all copies are isomorphic, we assign a context to each sub-module
+ (shared among the copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) will be called
+ before the callback of any slave copy.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
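+ # ctxs[j] is handed to copy i of sub-module j on every device; this shared
+ # object is how the copies exchange their SyncMaster / SlavePipe handles.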
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
+ Useful when you have a customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = encoding.nn.BatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = encoding.nn.BatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelModel(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate