forked from yanwei-li/PyTorch-Encoding
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3084471
commit 8d25e1c
Showing
6 changed files
with
338 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
.. role:: hidden | ||
:class: hidden-section | ||
|
||
encoding.nn | ||
=========== | ||
|
||
Customized NN modules in Encoding Package. For Synchronized Cross-GPU Batch Normalization, please visit :class:`encoding.nn.SyncBatchNorm2d`. | ||
|
||
.. currentmodule:: encoding.nn | ||
|
||
:hidden:`Encoding` | ||
~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: Encoding | ||
:members: | ||
|
||
:hidden:`SyncBatchNorm2d` | ||
~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: SyncBatchNorm2d | ||
:members: | ||
|
||
:hidden:`SyncBatchNorm1d` | ||
~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: SyncBatchNorm1d | ||
:members: | ||
|
||
:hidden:`SyncBatchNorm3d` | ||
~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: SyncBatchNorm3d | ||
:members: | ||
|
||
:hidden:`Inspiration` | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: Inspiration | ||
:members: | ||
|
||
:hidden:`UpsampleConv2d` | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: UpsampleConv2d | ||
:members: | ||
|
||
:hidden:`DilatedAvgPool2d` | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: DilatedAvgPool2d | ||
:members: | ||
|
||
:hidden:`GramMatrix` | ||
~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: GramMatrix | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# -*- coding: utf-8 -*- | ||
# File : comm.py | ||
# Author : Jiayuan Mao | ||
# Email : maojiayuan@gmail.com | ||
# Date : 27/01/2018 | ||
# | ||
# This file is part of Synchronized-BatchNorm-PyTorch. | ||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch | ||
# Distributed under MIT License. | ||
|
||
import queue | ||
import collections | ||
import threading | ||
|
||
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] | ||
|
||
|
||
class FutureResult(object): | ||
"""A thread-safe future implementation. Used only as one-to-one pipe.""" | ||
|
||
def __init__(self): | ||
self._result = None | ||
self._lock = threading.Lock() | ||
self._cond = threading.Condition(self._lock) | ||
|
||
def put(self, result): | ||
with self._lock: | ||
assert self._result is None, 'Previous result has\'t been fetched.' | ||
self._result = result | ||
self._cond.notify() | ||
|
||
def get(self): | ||
with self._lock: | ||
if self._result is None: | ||
self._cond.wait() | ||
|
||
res = self._result | ||
self._result = None | ||
return res | ||
|
||
|
||
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) | ||
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) | ||
|
||
|
||
class SlavePipe(_SlavePipeBase): | ||
"""Pipe for master-slave communication.""" | ||
|
||
def run_slave(self, msg): | ||
self.queue.put((self.identifier, msg)) | ||
ret = self.result.get() | ||
self.queue.put(True) | ||
return ret | ||
|
||
|
||
class SyncMaster(object): | ||
"""An abstract `SyncMaster` object. | ||
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should | ||
call `register(id)` and obtain an `SlavePipe` to communicate with the master. | ||
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, | ||
and passed to a registered callback. | ||
- After receiving the messages, the master device should gather the information and determine to message passed | ||
back to each slave devices. | ||
""" | ||
|
||
def __init__(self, master_callback): | ||
""" | ||
Args: | ||
master_callback: a callback to be invoked after having collected messages from slave devices. | ||
""" | ||
self._master_callback = master_callback | ||
self._queue = queue.Queue() | ||
self._registry = collections.OrderedDict() | ||
self._activated = False | ||
|
||
def register_slave(self, identifier): | ||
""" | ||
Register an slave device. | ||
Args: | ||
identifier: an identifier, usually is the device id. | ||
Returns: a `SlavePipe` object which can be used to communicate with the master device. | ||
""" | ||
if self._activated: | ||
assert self._queue.empty(), 'Queue is not clean before next initialization.' | ||
self._activated = False | ||
self._registry.clear() | ||
future = FutureResult() | ||
self._registry[identifier] = _MasterRegistry(future) | ||
return SlavePipe(identifier, self._queue, future) | ||
|
||
def run_master(self, master_msg): | ||
""" | ||
Main entry for the master device in each forward pass. | ||
The messages were first collected from each devices (including the master device), and then | ||
an callback will be invoked to compute the message to be sent back to each devices | ||
(including the master device). | ||
Args: | ||
master_msg: the message that the master want to send to itself. This will be placed as the first | ||
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. | ||
Returns: the message to be sent back to the master device. | ||
""" | ||
self._activated = True | ||
|
||
intermediates = [(0, master_msg)] | ||
for i in range(self.nr_slaves): | ||
intermediates.append(self._queue.get()) | ||
|
||
results = self._master_callback(intermediates) | ||
assert results[0][0] == 0, 'The first result should belongs to the master.' | ||
|
||
for i, res in results: | ||
if i == 0: | ||
continue | ||
self._registry[i].result.put(res) | ||
|
||
for i in range(self.nr_slaves): | ||
assert self._queue.get() is True | ||
|
||
return results[0][1] | ||
|
||
@property | ||
def nr_slaves(self): | ||
return len(self._registry) |
Oops, something went wrong.