From cebf13414afb9aafbe28ff25fe7fcd256051bf07 Mon Sep 17 00:00:00 2001
From: Hang Zhang <8041160+zhanghang1989@users.noreply.github.com>
Date: Tue, 15 May 2018 21:41:38 -0700
Subject: [PATCH] Adapt SyncBN API from Other's Work (#52)

* update and fix bugs

* adapt syncbn api from other work

* typo
---
 build.py                  |  12 +++-
 docs/source/functions.rst |   2 +-
 docs/source/index.rst     |   2 -
 docs/source/nn.rst        |  57 +++++++++++++++++
 encoding/nn/comm.py       | 131 ++++++++++++++++++++++++++++++++++++++
 encoding/nn/syncbn.py     | 102 ++++++++++++++++++++---------
 encoding/parallel.py      |  69 +++++++++++++++++++-
 7 files changed, 339 insertions(+), 36 deletions(-)
 create mode 100644 docs/source/nn.rst
 create mode 100644 encoding/nn/comm.py

diff --git a/build.py b/build.py
index 0c5375fa..489e69ef 100644
--- a/build.py
+++ b/build.py
@@ -14,6 +14,8 @@ import subprocess
 
 from torch.utils.ffi import create_extension
 
+torch_ver = torch.__version__[:3]
+
 lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
 cwd = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'encoding/')
 encoding_lib_path = os.path.join(cwd, "lib")
@@ -25,12 +27,18 @@
 # build CUDA library
 os.environ['TORCH_BUILD_DIR'] = lib_path
 if platform.system() == 'Darwin':
-    os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
+    if torch_ver == '0.3':
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.1.dylib')
+    else:
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
     ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.dylib')
 else:
     os.environ['CFLAGS'] = '-std=c99'
-    os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
+    if torch_ver == '0.3':
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so.1')
+    else:
+        os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
     ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.so')
 
 build_all_cmd = ['bash', 'encoding/make.sh']

diff --git a/docs/source/functions.rst b/docs/source/functions.rst
index 3b083dae..144a7a24 100644
--- a/docs/source/functions.rst
+++ b/docs/source/functions.rst
@@ -4,7 +4,7 @@
 encoding.functions
 ==================
 
-.. automodule:: encoding.Functions
+.. automodule:: encoding.functions
 
 .. currentmodule:: encoding.functions

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9b84e66a..5302df1f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -9,8 +9,6 @@ Created by `Hang Zhang `_
 
 An optimized PyTorch package with CUDA backend.
 
-.. note::
-    PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.SyncBatchNorm2d` and the `MNIST example `_.
 
 .. toctree::
    :glob:

diff --git a/docs/source/nn.rst b/docs/source/nn.rst
new file mode 100644
index 00000000..dd0a6450
--- /dev/null
+++ b/docs/source/nn.rst
@@ -0,0 +1,57 @@
+.. role:: hidden
+    :class: hidden-section
+
+encoding.nn
+===========
+
+Customized NN modules in the Encoding package. For Synchronized Cross-GPU Batch Normalization, please visit :class:`encoding.nn.SyncBatchNorm2d`.
+
+.. currentmodule:: encoding.nn
+
+:hidden:`Encoding`
+~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Encoding
+    :members:
+
+:hidden:`SyncBatchNorm2d`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SyncBatchNorm2d
+    :members:
+
+:hidden:`SyncBatchNorm1d`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SyncBatchNorm1d
+    :members:
+
+:hidden:`SyncBatchNorm3d`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: SyncBatchNorm3d
+    :members:
+
+:hidden:`Inspiration`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Inspiration
+    :members:
+
+:hidden:`UpsampleConv2d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: UpsampleConv2d
+    :members:
+
+:hidden:`DilatedAvgPool2d`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DilatedAvgPool2d
+    :members:
+
+:hidden:`GramMatrix`
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: GramMatrix
+    :members:

diff --git a/encoding/nn/comm.py b/encoding/nn/comm.py
new file mode 100644
index 00000000..b64bf6ba
--- /dev/null
+++ b/encoding/nn/comm.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# File   : comm.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as a one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, 'Previous result hasn\'t been fetched.'
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
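`FutureResult` above is a strict one-to-one pipe: one thread blocks in `get()` until another calls `put()`, and each result is fetched exactly once. A minimal sketch of that contract (not part of the patch; thread and variable names are invented):

    import threading

    future = FutureResult()

    def consumer():
        # blocks until the producer calls put(); fetching clears the slot,
        # so the pipe can be reused for the next round
        print('got:', future.get())

    t = threading.Thread(target=consumer)
    t.start()
    future.put(42)  # wakes the waiting consumer
    t.join()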
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During replication, as data parallel will trigger a callback on each module, every slave device should
+      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
+      collected and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message
+      to be passed back to each slave device.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def register_slave(self, identifier):
+        """
+        Register a slave device.
+
+        Args:
+            identifier: an identifier, usually the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback is invoked to compute the message to be sent back to each device
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
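`SyncMaster` and `SlavePipe` together implement a barrier-style exchange: every slave pushes its message into the shared queue, the master runs the callback over all collected messages, and each slave blocks on its `FutureResult` until the master broadcasts the replies. A minimal sketch of the protocol (not part of the patch; the summing callback and the thread setup are invented):

    import threading

    def master_callback(intermediates):
        # intermediates is a list of (identifier, msg), the master's entry first;
        # reply to every device with the grand total
        total = sum(msg for _, msg in intermediates)
        return [(identifier, total) for identifier, _ in intermediates]

    sync = SyncMaster(master_callback)
    pipe = sync.register_slave(1)  # slaves must register before run_master

    replies = {}

    def slave():
        replies[1] = pipe.run_slave(2)  # blocks until the master replies

    t = threading.Thread(target=slave)
    t.start()
    replies[0] = sync.run_master(1)  # collects (0, 1) and (1, 2), broadcasts 3
    t.join()
    assert replies == {0: 3, 1: 3}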
diff --git a/encoding/nn/syncbn.py b/encoding/nn/syncbn.py
index 53e6c2b5..7005789a 100644
--- a/encoding/nn/syncbn.py
+++ b/encoding/nn/syncbn.py
@@ -9,7 +9,6 @@
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
 """Synchronized Cross-GPU Batch Normalization Module"""
-import functools
 import collections
 import threading
 import torch
@@ -22,52 +21,96 @@
 
 from ..functions import *
 from ..parallel import allreduce
+from .comm import SyncMaster
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
            'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
            'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
 
+# Adapted from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+_ChildMessage = collections.namedtuple('Message', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
 
 class _SyncBatchNorm(_BatchNorm):
     def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
         super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
 
+        self._sync_master = SyncMaster(self._data_parallel_master)
+
+        self._is_parallel = False
         self._parallel_id = None
         self._slave_pipe = None
-        self.sharedT = SharedTensor(torch.cuda.device_count())
 
     def forward(self, input):
+        # If this is not a parallel computation or we are in evaluation mode,
+        # use PyTorch's implementation.
+        if not (self._is_parallel and self.training):
+            return batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training, self.momentum, self.eps)
+
+        # Resize the input to (B, C, -1).
         input_shape = input.size()
-        input = input.view(input_shape[0], self.num_features, -1)
-        if not self.training:
-            std = (self.running_var.clamp(self.eps)).sqrt()
-            output = batchnormeval(input, self.weight, self.bias, self.running_mean, std)
-            return output.view(input_shape)
+        input = input.view(input.size(0), self.num_features, -1)
 
         # sum(x) and sum(x^2)
         N = input.size(0) * input.size(2)
         xsum, xsqsum = sum_square(input)
 
         # all-reduce for global sum(x) and sum(x^2)
-        igpu = input.get_device()
-        self.sharedT.push(N, igpu, xsum, xsqsum)
-        N, xsum, xsqsum = self.sharedT.pull(igpu)
+        if self._parallel_id == 0:
+            mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
+        else:
+            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
+        # forward
+        return batchnormtrain(input, self.weight, self.bias, mean, 1.0/inv_std).view(input_shape)
 
-        # calculate mean, var
-        mean = xsum / N
-        sumvar = xsqsum - xsum * xsum / N
-        unbias_var = sumvar / (N - 1)
-        bias_var = sumvar / N
-        std = bias_var.clamp(self.eps).sqrt()
-        # update running_mean and var
-        self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean.data
-        self.running_var = (1-self.momentum) * self.running_var + self.momentum * unbias_var.data
+    def __data_parallel_replicate__(self, ctx, copy_id):
+        self._is_parallel = True
+        self._parallel_id = copy_id
 
-        # forward
-        return batchnormtrain(input, self.weight, self.bias, mean, std).view(input_shape)
+        # parallel_id == 0 means master device.
+        if self._parallel_id == 0:
+            ctx.sync_master = self._sync_master
+        else:
+            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+    def _data_parallel_master(self, intermediates):
+        """Reduce the sum and square-sum, compute the statistics, and broadcast them."""
+
+        # Always using the same "device order" makes the ReduceAdd operation faster.
+        # Thanks to: Tete Xiao (http://tetexiao.com/)
+        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+        to_reduce = [i[1][:2] for i in intermediates]
+        to_reduce = [j for i in to_reduce for j in i]  # flatten
+        target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+        sum_size = sum([i[1].sum_size for i in intermediates])
+        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+        outputs = []
+        for i, rec in enumerate(intermediates):
+            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+        return outputs
+
+    def _compute_mean_std(self, sum_, ssum, size):
+        """Compute the mean and standard-deviation with sum and square-sum. This method
+        also maintains the moving average on the master device."""
+        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+        mean = sum_ / size
+        sumvar = ssum - sum_ * mean
+        unbias_var = sumvar / (size - 1)
+        bias_var = sumvar / size
+
+        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+        return mean, bias_var.clamp(self.eps) ** -0.5
 
 
 class BatchNorm1d(_SyncBatchNorm):
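`_compute_mean_std` uses the sufficient-statistics identity n * var(x) = sum(x^2) - sum(x) * mean, so each device ships only two tensors per channel regardless of its batch size. A worked sketch of the same arithmetic on CPU tensors (not part of the patch; the numbers are invented):

    import torch

    # per-device statistics for one channel: GPU 0 saw [1, 2, 3], GPU 1 saw [2, 3, 4]
    sums  = [torch.tensor([6.0]),  torch.tensor([9.0])]   # sum(x) per device
    ssums = [torch.tensor([14.0]), torch.tensor([29.0])]  # sum(x^2) per device
    size  = 6                                             # total element count

    sum_ = sums[0] + sums[1]                # 15.0, as after ReduceAddCoalesced
    ssum = ssums[0] + ssums[1]              # 43.0
    mean = sum_ / size                      # 2.5
    sumvar = ssum - sum_ * mean             # 5.5
    bias_var = sumvar / size                # biased variance, normalizes the activations
    unbias_var = sumvar / (size - 1)        # unbiased variance, updates running_var
    inv_std = bias_var.clamp(1e-5) ** -0.5  # same form as the method's return value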
@@ -82,13 +125,15 @@ def _check_input_dim(self, input):
 
 class BatchNorm2d(_SyncBatchNorm):
     r"""Cross-GPU Synchronized Batch normalization (SyncBN)
 
-    Standard BN [1]_ implementation only normalize the data within each device.
+    Standard BN [1]_ implementation only normalizes the data within each device (GPU).
     SyncBN normalizes the input within the whole mini-batch.
     We follow the sync-once implementation described in the paper [2]_ .
     Please see the design idea in the `notes <./notes/syncbn.html>`_.
 
     .. note::
-        Please use ``CUDA_VISIBLE_DEVICES`` to select number of GPUs.
+        We adapt the awesome python API from another `PyTorch SyncBN Implementation
+        `_ and provide an
+        efficient CUDA backend.
 
     .. math::

@@ -125,9 +170,9 @@ class BatchNorm2d(_SyncBatchNorm):
     .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
 
     Examples:
-        >>> # Use exactly the same as standard BatchNrom2d
         >>> m = BatchNorm2d(100)
        >>> net = torch.nn.DataParallel(m)
+        >>> encoding.parallel.patch_replication_callback(net)
         >>> output = net(input)
     """
     def _check_input_dim(self, input):

@@ -148,11 +193,12 @@ def _check_input_dim(self, input):
 
 class SharedTensor(object):
     """Shared Tensor for cross GPU all reduce operation"""
-    def __init__(self, nGPUs):
+    def __init__(self, nGPUs, op):
         self.mutex = threading.Lock()
         self.all_tasks_done = threading.Condition(self.mutex)
         self.nGPUs = nGPUs
         self._clear()
+        self.op = op
 
     def _clear(self):
         self.N = 0
@@ -160,7 +206,7 @@
         self.push_tasks = self.nGPUs
         self.reduce_tasks = self.nGPUs
 
-    def push(self, *inputs):
+    def __call__(self, *inputs):
         if self.nGPUs <= 1:
             return tuple(inputs)
         # push from device
@@ -177,15 +223,13 @@
             self.all_tasks_done.notify_all()
             while self.push_tasks:
                 self.all_tasks_done.wait()
-
-    def pull(self, igpu):
         # pull from device
         with self.mutex:
             if igpu == 0:
                 assert(len(self.dict) == self.nGPUs)
                 # flatten the tensors
                 self.list = [t for i in range(len(self.dict)) for t in self.dict[i]]
-                self.outlist = allreduce(2, *self.list)
+                self.outlist = self.op(2, *self.list)
                 self.reduce_tasks -= 1
             else:
                 self.reduce_tasks -= 1

diff --git a/encoding/parallel.py b/encoding/parallel.py
index 3b521132..b979ca28 100644
--- a/encoding/parallel.py
+++ b/encoding/parallel.py
@@ -10,6 +10,7 @@
 
 """Encoding Data Parallel"""
 import threading
+import functools
 import torch
 from torch.autograd import Variable, Function
 import torch.cuda.comm as comm
@@ -17,10 +18,11 @@
 from torch.nn.parallel.parallel_apply import get_a_var
 from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
 
-__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion']
-
 torch_ver = torch.__version__[:3]
 
+__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
+           'patch_replication_callback']
+
 def allreduce(*inputs):
     """Cross GPU all reduce autograd operation for calculating mean and variance in SyncBN.
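`allreduce` is consumed later in this patch as `self.op(2, *self.list)`: the leading argument is the number of tensors each device contributes, followed by the per-device tensors flattened in device order. A hedged sketch of that calling convention (not part of the patch; input names are invented and two GPUs are assumed):

    # each device contributes two tensors: sum(x) and sum(x^2)
    xsum0, xsq0 = sum_square(input0)  # resident on cuda:0
    xsum1, xsq1 = sum_square(input1)  # resident on cuda:1
    # after the all-reduce, every device holds the global sums
    reduced = allreduce(2, xsum0, xsq0, xsum1, xsq1)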
@@ -94,6 +96,11 @@ class DataParallelModel(DataParallel):
     def gather(self, outputs, output_device):
         return outputs
 
+    def replicate(self, module, device_ids):
+        modules = super(DataParallelModel, self).replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
 
 class DataParallelCriterion(DataParallel):
     """
@@ -181,3 +188,61 @@ def _worker(i, module, input, target, kwargs, device=None):
             raise output
         outputs.append(output)
     return outputs
+
+###########################################################################
+# Adapted from Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute a replication callback `__data_parallel_replicate__` on each module created
+    by the original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
+
+    Note that, as all replicated modules are isomorphic, we assign each sub-module a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be called ahead
+    of the callback of any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+def patch_replication_callback(data_parallel):
+    """
+    Monkey-patch an existing `DataParallel` object. Add the replication callback.
+    Useful when you have a customized `DataParallel` implementation.
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+        > patch_replication_callback(sync_bn)
+        # this is equivalent to
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+    """
+
+    assert isinstance(data_parallel, DataParallel)
+
+    old_replicate = data_parallel.replicate
+
+    @functools.wraps(old_replicate)
+    def new_replicate(module, device_ids):
+        modules = old_replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+    data_parallel.replicate = new_replicate
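To tie the callback machinery together: any sub-module that defines `__data_parallel_replicate__` receives a per-sub-module context shared by its copies on all devices, and the master copy (copy_id 0) always runs first. A minimal sketch (not part of the patch; the module is invented and at least two GPUs are assumed):

    import torch.nn as nn

    class ReplicationAware(nn.Module):
        def forward(self, x):
            return x

        def __data_parallel_replicate__(self, ctx, copy_id):
            # ctx is shared by every copy of this sub-module across devices
            if copy_id == 0:
                ctx.master = self         # the master copy publishes shared state...
            else:
                self.master = ctx.master  # ...and slave copies read it back

    net = nn.DataParallel(nn.Sequential(ReplicationAware(), nn.Linear(4, 4)),
                          device_ids=[0, 1])
    patch_replication_callback(net)  # replicate() now triggers the callbacks above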