Commit

update and fix bugs (zhanghang1989#51)
zhanghang1989 committed May 15, 2018
1 parent 71447e1 commit 67e153d
Showing 20 changed files with 617 additions and 167 deletions.
19 changes: 10 additions & 9 deletions build.py
@@ -15,8 +15,8 @@
from torch.utils.ffi import create_extension

lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
cwd = os.path.dirname(os.path.realpath(__file__))
encoding_lib_path = os.path.join(cwd, "encoding", "lib")
cwd = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'encoding/')
encoding_lib_path = os.path.join(cwd, "lib")

# clean the build files
clean_cmd = ['bash', 'clean.sh']
@@ -25,13 +25,13 @@
# build CUDA library
os.environ['TORCH_BUILD_DIR'] = lib_path
if platform.system() == 'Darwin':
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.1.dylib')
ENCODING_LIB = os.path.join(cwd, 'encoding/lib/libENCODING.dylib')
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.dylib')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.dylib')

else:
os.environ['CFLAGS'] = '-std=c99'
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so.1')
ENCODING_LIB = os.path.join(cwd, 'encoding/lib/libENCODING.so')
os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libATen.so')
ENCODING_LIB = os.path.join(cwd, 'lib/libENCODING.so')

build_all_cmd = ['bash', 'encoding/make.sh']
subprocess.check_call(build_all_cmd, env=dict(os.environ))
@@ -45,9 +45,9 @@
with_cuda = True

include_path = [os.path.join(lib_path, 'include'),
os.path.join(cwd,'encoding/kernel'),
os.path.join(cwd,'encoding/kernel/include'),
os.path.join(cwd,'encoding/src/')]
os.path.join(cwd,'kernel'),
os.path.join(cwd,'kernel/include'),
os.path.join(cwd,'src/')]

def make_relative_rpath(path):
if platform.system() == 'Darwin':
@@ -63,6 +63,7 @@ def make_relative_rpath(path):
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
extra_compile_args=["-std=c99"],
include_dirs = include_path,
extra_link_args = [
make_relative_rpath(lib_path),
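
The gist of this file's change: the build now resolves paths relative to encoding/ and links against the unversioned ATen library names. Below is a small sanity-check sketch (not part of the commit) that mirrors that logic and fails early if the expected library is missing; everything beyond the file names shown above is illustrative.

# Illustrative sketch only: check that the ATen library build.py will point
# TH_LIBRARIES at actually exists before kicking off encoding/make.sh.
import os
import platform

import torch

lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
aten_name = 'libATen.dylib' if platform.system() == 'Darwin' else 'libATen.so'
aten_path = os.path.join(lib_path, aten_name)

if not os.path.isfile(aten_path):
    # Older builds shipped versioned names (libATen.1.dylib / libATen.so.1),
    # which is what the path change above moves away from.
    raise RuntimeError('expected ATen library not found: %s' % aten_path)
print('will link against', aten_path)
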
3 changes: 1 addition & 2 deletions clean.sh
@@ -1,3 +1,2 @@
#!/usr/bin/env bash

rm -rf build/ dist/ encoding.egg-info/ encoding/lib/ encoding/_ext/ __pycache__ encoding/__pycache__
rm -rf build/ dist/ torch_encoding.egg-info/ encoding/lib/ encoding/_ext/ __pycache__ encoding/__pycache__
2 changes: 1 addition & 1 deletion docs/source/dilated.rst
@@ -1,7 +1,7 @@
.. role:: hidden
:class: hidden-section

Dilated Networks
encoding.dilated
================

We provide correctly dilated pre-trained ResNet and DenseNet models (stride of 8) for semantic segmentation.
12 changes: 11 additions & 1 deletion docs/source/functions.rst
@@ -4,10 +4,20 @@
encoding.functions
==================

.. automodule:: encoding.functions
.. automodule:: encoding.Functions

.. currentmodule:: encoding.functions


:hidden:`batchnorm`
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: batchnorm

:hidden:`batchnormeval`
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: batchnormeval

:hidden:`dilatedavgpool2d`
~~~~~~~~~~~~~~~~~~~~~~~~~~

5 changes: 2 additions & 3 deletions docs/source/index.rst
@@ -10,7 +10,7 @@ Created by `Hang Zhang <http://hangzh.com/>`_
An optimized PyTorch package with CUDA backend.

.. note::
Please check out the PyTorch-compatible Synchronized Cross-GPU :class:`encoding.nn.BatchNorm2d` and the `mnist example <https://github.com/zhanghang1989/PyTorch-SyncBatchNorm>`_.
Please see the PyTorch-compatible Synchronized Cross-GPU :class:`encoding.nn.SyncBatchNorm2d` and the `MNIST example <https://github.com/zhanghang1989/PyTorch-SyncBatchNorm>`_.

.. toctree::
:glob:
@@ -30,8 +30,7 @@ An optimized PyTorch package with CUDA backend.
:maxdepth: 1
:caption: Package Reference

encoding
syncbn
nn
parallel
dilated
functions
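
A hedged usage sketch for the synchronized BatchNorm referenced in the note above. The class name encoding.nn.SyncBatchNorm2d is taken from the updated note; a BatchNorm2d-compatible constructor is an assumption, so treat this as a sketch rather than the package's documented API.

# Sketch only: drop the synchronized cross-GPU BatchNorm in where nn.BatchNorm2d
# would normally go. Constructor compatibility with nn.BatchNorm2d is assumed.
import torch.nn as nn
import encoding

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    encoding.nn.SyncBatchNorm2d(64),   # assumed name, taken from the note above
    nn.ReLU(inplace=True),
).cuda()
# Synchronization only matters once the model is replicated across GPUs,
# e.g. with the DataParallelModel wrapper documented later in this commit.
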
8 changes: 3 additions & 5 deletions docs/source/notes/compile.rst
@@ -5,22 +5,20 @@ Install and Citations
Install from Source
-------------------

* Install PyTorch from Source (recommended). Please follow the `PyTorch instructions <https://github.com/pytorch/pytorch#from-source>`_.

* Install this package
* Install PyTorch by following the `PyTorch instructions <http://pytorch.org/>`_.
* Install from source

- Clone the repo::

git clone https://github.com/zhanghang1989/PyTorch-Encoding && cd PyTorch-Encoding

- On Linux::

pip install -r requirements.txt
python setup.py install

- On Mac OSX::

pip install -r requirements.txt
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install

Citations
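
A quick post-install sanity check (not from the docs): importing the package and its compiled extension confirms the build step succeeded. The module names come from elsewhere in this commit (the top-level encoding package and encoding._ext.encoding_lib imported in syncbn.py).

# Run after `python setup.py install`; sketch only.
import encoding                          # the installed package
from encoding._ext import encoding_lib   # the compiled CUDA extension

print('encoding imported from', encoding.__file__)
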
14 changes: 7 additions & 7 deletions docs/source/parallel.rst
@@ -1,10 +1,10 @@
.. role:: hidden
:class: hidden-section

Data Parallel
=============
encoding.parallel
=================

- Current PyTorch DataParallel does not support multi-GPU loss calculation, which makes the GPU memory usage very unbalanced. We address this issue here by doing Model & Criterion DataParallel.
- Current PyTorch DataParallel does not support multi-GPU loss calculation, which makes the GPU memory usage very unbalanced. We address this issue here by doing DataParallel for Model & Criterion.

.. note::
This code is provided together with the paper
@@ -15,16 +15,16 @@ Data Parallel
.. automodule:: encoding.parallel
.. currentmodule:: encoding.parallel

:hidden:`ModelDataParallel`
:hidden:`DataParallelModel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ModelDataParallel
.. autoclass:: DataParallelModel
:members:

:hidden:`CriterionDataParallel`
:hidden:`DataParallelCriterion`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: CriterionDataParallel
.. autoclass:: DataParallelCriterion
:members:


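
A hedged sketch of the usage pattern these docs describe: wrap the network in DataParallelModel and the loss in DataParallelCriterion so the loss is computed on each GPU instead of after gathering all outputs to one device. The class names come from the docs above; nn.DataParallel-style constructors and the list-of-outputs convention are assumptions.

# Sketch only; signatures are assumed to mirror torch.nn.DataParallel.
import torch
import torch.nn as nn
from torch.autograd import Variable
from encoding.parallel import DataParallelModel, DataParallelCriterion

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 1, 1))
model = DataParallelModel(net.cuda())                   # replicate the model per GPU
criterion = DataParallelCriterion(nn.MSELoss().cuda())  # compute the loss per GPU

x = Variable(torch.randn(8, 3, 32, 32).cuda())
y = Variable(torch.randn(8, 1, 32, 32).cuda())

outputs = model(x)            # assumed: per-GPU outputs, kept on their devices
loss = criterion(outputs, y)  # assumed: scatters targets, reduces the losses
loss.backward()
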
20 changes: 12 additions & 8 deletions docs/source/utils.rst
@@ -1,20 +1,14 @@
.. role:: hidden
:class: hidden-section

My PyTorch Utils
================
encoding.utils
==============

Useful util functions.

.. automodule:: encoding.utils
.. currentmodule:: encoding.utils

:hidden:`LR_Scheduler`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LR_Scheduler
:members:

:hidden:`get_optimizer`
~~~~~~~~~~~~~~~~~~~~~~~

@@ -24,3 +18,13 @@ Useful util functions.
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: save_checkpoint

:hidden:`batch_pix_accuracy`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: batch_pix_accuracy

:hidden:`batch_intersection_union`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: batch_intersection_union
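
For orientation, an illustrative (re)implementation of what metrics with these names typically compute for semantic segmentation: total correct pixels vs. labeled pixels, and per-class intersection vs. union. The actual encoding.utils signatures and label conventions may differ; nothing below is taken from the package itself.

# Illustrative only; not the encoding.utils implementation. Assumes `output`
# holds per-class scores (N, C, H, W) and `target` holds labels (N, H, W),
# with 0 treated as ignore/background as is common in segmentation datasets.
import torch

def batch_pix_accuracy(output, target):
    pred = torch.max(output, 1)[1] + 1          # predicted label per pixel, 1-based
    target = target + 1
    labeled = (target > 0).long()
    correct = ((pred == target).long() * labeled).sum()
    return correct, labeled.sum()               # caller divides: correct / labeled

def batch_intersection_union(output, target, nclass):
    pred = torch.max(output, 1)[1] + 1
    target = target + 1
    pred = pred * (target > 0).long()           # mask out ignored pixels
    inter = pred * (pred == target).long()
    area_inter = torch.histc(inter.float(), bins=nclass, min=1, max=nclass)
    area_pred = torch.histc(pred.float(), bins=nclass, min=1, max=nclass)
    area_target = torch.histc(target.float(), bins=nclass, min=1, max=nclass)
    return area_inter, area_pred + area_target - area_inter  # IoU = inter / union
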
49 changes: 26 additions & 23 deletions encoding/dilated/resnet.py
@@ -1,7 +1,8 @@
"""Dilated ResNet"""
import math
import torch.utils.model_zoo as model_zoo
from .. import nn
#from .. import nn
import torch.nn as nn

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'BasicBlock', 'Bottleneck']
@@ -25,15 +26,16 @@ class BasicBlock(nn.Module):
"""ResNet BasicBlock
"""
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, first_dilation=1):
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, first_dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=first_dilation, dilation=first_dilation, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride

@@ -62,18 +64,18 @@ class Bottleneck(nn.Module):
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1,
downsample=None, first_dilation=1):
downsample=None, first_dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.bn2 = norm_layer(planes)
self.conv3 = nn.Conv2d(
planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.bn3 = norm_layer(planes * 4)
self.relu = nn.ReLU(inplace=False)
self.downsample = downsample
self.dilation = dilation
self.stride = stride
@@ -118,51 +120,52 @@ class ResNet(nn.Module):
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
"""
# pylint: disable=unused-variable
def __init__(self, block, layers, num_classes=1000):
def __init__(self, block, layers, num_classes=1000, norm_layer=None):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.bn1 = norm_layer(64)
self.relu = nn.ReLU(inplace=False)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
elif isinstance(m, norm_layer):
m.weight.data.fill_(1)
m.bias.data.zero_()

def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
norm_layer(planes * block.expansion),
)

layers = []
if dilation == 1 or dilation == 2:
layers.append(block(self.inplanes, planes, stride, dilation=1,
downsample=downsample, first_dilation=dilation))
downsample=downsample, first_dilation=dilation, norm_layer=norm_layer))
elif dilation == 4:
layers.append(block(self.inplanes, planes, stride, dilation=2,
downsample=downsample, first_dilation=dilation))
downsample=downsample, first_dilation=dilation, norm_layer=norm_layer))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))

self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation, first_dilation=dilation))
layers.append(block(self.inplanes, planes, dilation=dilation, first_dilation=dilation,
norm_layer=norm_layer))

return nn.Sequential(*layers)

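
The main change in this file is threading a norm_layer argument through the blocks so callers can inject a different normalization layer (for example a synchronized BatchNorm) instead of the hard-coded nn.BatchNorm2d. A minimal sketch of constructing the network with an explicit norm_layer, using only names visible in the diff:

# Sketch: build the dilated ResNet with an injected normalization layer.
# ResNet and Bottleneck are in __all__ of encoding/dilated/resnet.py; passing
# norm_layer explicitly avoids the None default shown above.
import torch.nn as nn
from encoding.dilated.resnet import ResNet, Bottleneck

# Plain BatchNorm2d here; any class with the same constructor (e.g. a
# synchronized variant) could be swapped in, which appears to be the point
# of the new argument.
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1000,
               norm_layer=nn.BatchNorm2d)
print(sum(p.data.numel() for p in model.parameters()), 'parameters')  # ResNet-50-sized
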
32 changes: 16 additions & 16 deletions encoding/functions/syncbn.py
@@ -8,55 +8,54 @@
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Synchronized Batch Normalization functions"""
"""Synchronized Cross-GPU Batch Normalization functions"""
import torch
from torch.autograd import Function, Variable
from torch.autograd import Variable, Function
from .._ext import encoding_lib

__all__ = ['sum_square', 'batchnormtrain', 'batchnormeval']

def sum_square(input):
r"""Calculate sum of elements and sum of squares for Batch Normalization"""
return _sum_square.apply(input)


class _sum_square(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
B, C, _, _ = input.size()
C = input.size(1)
with torch.cuda.device_of(input):
xsum = input.new().resize_(C).zero_()
xsquare = input.new().resize_(C).zero_()
if isinstance(input, torch.cuda.FloatTensor):
with torch.cuda.device_of(input):
encoding_lib.Encoding_Float_sum_square_Forward(
input.view(B, C, -1), xsum, xsquare)
input, xsum, xsquare)
elif isinstance(input, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input):
encoding_lib.Encoding_Double_sum_square_Forward(
input.view(B, C, -1), xsum, xsquare)
input, xsum, xsquare)
else:
raise RuntimeError('Unimplemented data type!')
raise RuntimeError('Unimplemented data type!', type(input))
return xsum, xsquare

@staticmethod
def backward(ctx, gradSum, gradSquare):
input, = ctx.saved_variables
B, C, H, W = input.data.size()
with torch.cuda.device_of(input.data):
gradInput = Variable(input.data.new().resize_(B, C, H*W).zero_())
gradInput = Variable(input.data.new().resize_as_(input.data).zero_())
if isinstance(input.data, torch.cuda.FloatTensor):
with torch.cuda.device_of(input.data):
encoding_lib.Encoding_Float_sum_square_Backward(
gradInput, input.data.view(B, C, -1), gradSum, gradSquare)
gradInput.data, input.data, gradSum.data, gradSquare.data)
elif isinstance(input.data, torch.cuda.DoubleTensor):
with torch.cuda.device_of(input.data):
encoding_lib.Encoding_Double_sum_square_Backward(
gradInput, input.data.view(B, C, -1), gradSum, gradSquare)
gradInput.data, input.data, gradSum.data, gradSquare.data)
else:
raise RuntimeError('Unimplemented data type!')
return gradInput.view(B, C, H, W)


def sum_square(input):
r"""Calculate sum of elements and sum of squares for Batch Normalization"""
return _sum_square.apply(input)
return gradInput


class _batchnorm(Function):
@@ -134,3 +133,4 @@ def batchnormeval(input, gamma, beta, mean, std):
Please see encoding.batchnormtrain_
"""
return _batchnorm(False)(input, gamma, beta, mean, std)
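
As a hedged sketch of how these pieces fit together: sum_square returns per-channel sums of x and x^2 so the caller (e.g. the synchronized BatchNorm layer) can turn them into mean and std; across GPUs the partial sums can simply be added before normalizing. Only the function names and the batchnormeval signature visible in this file are used; the import path, the cross-device reduction, and the epsilon are assumptions.

# Sketch only: turn the per-channel sums from sum_square into batch statistics
# and feed them to batchnormeval. Cross-GPU usage would add the xsum/xsquare
# from every device before this step (assumption). CUDA-only per the type
# checks in this file.
import torch
from torch.autograd import Variable
from encoding.functions import sum_square, batchnormeval

x = Variable(torch.randn(4, 3, 8, 8).cuda())
gamma = Variable(torch.ones(3).cuda())
beta = Variable(torch.zeros(3).cuda())

xsum, xsquare = sum_square(x)                    # per-channel sums
n = x.size(0) * x.size(2) * x.size(3)            # elements per channel: N*H*W
mean = xsum / n
std = (xsquare / n - mean * mean + 1e-5).sqrt()  # eps=1e-5 is an assumption

y = batchnormeval(x, gamma, beta, mean, std)
print(y.size())                                  # same shape as x
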
