Skip to content

Commit

Permalink
Adapt SyncBN API from Other's Work (zhanghang1989#52)
Browse files Browse the repository at this point in the history
* update and fix bugs

* adapt syncbn api from other work

* typo
  • Loading branch information
zhanghang1989 committed May 16, 2018
1 parent 67e153d commit cebf134
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 36 deletions.
12 changes: 10 additions & 2 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import subprocess
from torch.utils.ffi import create_extension

torch_ver = torch.__version__[:3]

lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
cwd = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'encoding/')
encoding_lib_path = os.path.join(cwd, "lib")
Expand All @@ -25,12 +27,18 @@
# build CUDA library
os.environ['TORCH_BUILD_DIR'] = lib_path
if platform.system() == 'Darwin':
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
if torch_ver == '0.3':
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.1.dylib')
else:
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.dylib')

else:
os.environ['CFLAGS'] = '-std=c99'
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
if torch_ver == '0.3':
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so.1')
else:
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.so')

build_all_cmd = ['bash', 'encoding/make.sh']
Expand Down
2 changes: 1 addition & 1 deletion docs/source/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
encoding.functions
==================

.. automodule:: encoding.Functions
.. automodule:: encoding.functions

.. currentmodule:: encoding.functions

Expand Down
2 changes: 0 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ Created by `Hang Zhang <http://hangzh.com/>`_

An optimized PyTorch package with CUDA backend.

.. note::
PyTorch compatible Synchronized Cross-GPU :class:`encoding.nn.SyncBatchNorm2d` and the `MNIST example <https://github.com/zhanghang1989/PyTorch-SyncBatchNorm>`_.

.. toctree::
:glob:
Expand Down
57 changes: 57 additions & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
.. role:: hidden
:class: hidden-section

encoding.nn
===========

Customized NN modules in Encoding Package. For Synchronized Cross-GPU Batch Normalization, please visit :class:`encoding.nn.SyncBatchNorm2d`.

.. currentmodule:: encoding.nn

:hidden:`Encoding`
~~~~~~~~~~~~~~~~~~

.. autoclass:: Encoding
:members:

:hidden:`SyncBatchNorm2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SyncBatchNorm2d
:members:

:hidden:`SyncBatchNorm1d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SyncBatchNorm1d
:members:

:hidden:`SyncBatchNorm3d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SyncBatchNorm3d
:members:

:hidden:`Inspiration`
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Inspiration
:members:

:hidden:`UpsampleConv2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: UpsampleConv2d
:members:

:hidden:`DilatedAvgPool2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DilatedAvgPool2d
:members:

:hidden:`GramMatrix`
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GramMatrix
:members:
131 changes: 131 additions & 0 deletions encoding/nn/comm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

# Public API of this module: the thread-synchronization primitives used
# by the synchronized batch-norm implementation.
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    One thread calls :meth:`put` exactly once per exchange; another thread
    calls :meth:`get`, which blocks until the value is available and then
    consumes it (resetting the future for reuse).
    """

    def __init__(self):
        # Slot for the pending value; ``None`` means "no result yet", so
        # ``None`` itself cannot be transported through this future.
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Publish *result* for the (single) consumer thread.

        Raises:
            AssertionError: if the previous result was never fetched.
        """
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then consume and return it."""
        with self._lock:
            # ``while`` rather than ``if``: Condition.wait() may return
            # spuriously, so the predicate must be re-checked after every
            # wakeup before consuming the slot.
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


# Record the master keeps per registered slave: just the future through
# which that slave's result is delivered back to it.
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
# identifier: the slave's id; queue: the master's shared inbound queue;
# result: the per-slave future carrying the master's reply.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        """Send *msg* to the master, block for its reply, then acknowledge."""
        # Tag the message with our identifier so the master can route replies.
        self.queue.put((self.identifier, msg))
        # Block until the master publishes this slave's result.
        reply = self.result.get()
        # Acknowledge: tells the master this slave has consumed its result.
        self.queue.put(True)
        return reply


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    Coordinates one master device with several slave devices inside a single
    forward pass:

    - During the replication, as the data parallel will trigger a callback on
      each module, every slave device should call `register_slave(id)` and
      obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all
      messages from slave devices are collected and passed to the registered
      callback.
    - After receiving the messages, the master device gathers the information
      and determines the message to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected
                messages from slave devices.
        """
        self._master_callback = master_callback
        # Inbound queue shared by all slaves; also carries the `True` acks.
        self._queue = queue.Queue()
        # identifier -> _MasterRegistry(future); insertion order is the
        # registration order of the slaves.
        self._registry = collections.OrderedDict()
        # Becomes True after the first run_master; signals that a subsequent
        # register_slave call starts a fresh round (registry must be rebuilt).
        self._activated = False

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with
            the master device.
        """
        if self._activated:
            # A previous round completed; sanity-check that no stale messages
            # remain, then reset state before re-registering slaves.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        The messages are first collected from each device (including the
        master device), then a callback is invoked to compute the message to
        be sent back to each device (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself.
                This will be placed as the first message when calling
                `master_callback`. For detailed usage, see
                `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        # Identifier 0 is reserved for the master; slaves' (id, msg) pairs
        # arrive on the shared queue in whatever order the threads run.
        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        # Dispatch each slave's result through its registered future;
        # the master's own result (i == 0) is simply returned below.
        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Drain one `True` ack per slave (sent by SlavePipe.run_slave after it
        # consumes its result), guaranteeing the queue is empty for the next round.
        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        # Number of currently registered slave devices.
        return len(self._registry)
Loading

0 comments on commit cebf134

Please sign in to comment.