SyncBN backend and double type
zhanghang1989 committed Sep 18, 2017
1 parent 1633f31 commit fa0e478
Showing 13 changed files with 706 additions and 26 deletions.
150 changes: 146 additions & 4 deletions encoding/__init__.py
@@ -8,9 +8,12 @@
 ## LICENSE file in the root directory of this source tree
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

+import threading
 import torch
+import torch.cuda.nccl as nccl
 import torch.nn as nn
-from torch.autograd import Function
+from torch.autograd import Function, Variable
+from torch.nn.parameter import Parameter
 from ._ext import encoding_lib

 class aggregate(Function):
@@ -20,15 +23,26 @@ def forward(self, A, R):
         B, N, K, D = R.size()
         E = A.new(B,K,D)
         # TODO support cpu backend
-        encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
+        if isinstance(A, torch.cuda.FloatTensor):
+            encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
+        elif isinstance(A, torch.cuda.DoubleTensor):
+            encoding_lib.Encoding_Double_aggregate_forward(E, A, R)
+        else:
+            raise RuntimeError('unimplemented')
         return E

     def backward(self, gradE):
         A, R = self.saved_tensors
         gradA = A.new().resize_as_(A)
         gradR = R.new().resize_as_(R)
-        encoding_lib.Encoding_Float_aggregate_backward(gradA, gradR, gradE,
-                A, R)
+        if isinstance(A, torch.cuda.FloatTensor):
+            encoding_lib.Encoding_Float_aggregate_backward(gradA, gradR, gradE,
+                A, R)
+        elif isinstance(A, torch.cuda.DoubleTensor):
+            encoding_lib.Encoding_Double_aggregate_backward(gradA, gradR, gradE,
+                A, R)
+        else:
+            raise RuntimeError('unimplemented')
         return gradA, gradR
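
Note: the change above only adds dtype dispatch; the interface of aggregate is unchanged. A sketch of how the function is invoked, with shapes inferred from R.size() and A.new(B, K, D) above (assumes the compiled CUDA extension and a GPU; not part of this commit):

import torch
from torch.autograd import Variable
import encoding

B, N, K, D = 2, 100, 32, 64
A = Variable(torch.cuda.DoubleTensor(B, N, K).uniform_(), requires_grad=True)
R = Variable(torch.cuda.DoubleTensor(B, N, K, D).uniform_(), requires_grad=True)
E = encoding.aggregate()(A, R)   # DoubleTensor now supported alongside FloatTensor
E.sum().backward()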


@@ -82,3 +96,131 @@ def forward(self, X):
     def __repr__(self):
         return self.__class__.__name__ + '(' \
             + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' + str(self.D) + ')'
+
+class sum_square(Function):
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        B,C,H,W = input.size()
+        with torch.cuda.device_of(input):
+            xsum = input.new().resize_(C).zero_()
+            xsquare = input.new().resize_(C).zero_()
+        if isinstance(input, torch.cuda.FloatTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Float_sum_square_Forward(
+                    input.view(B,C,-1), xsum, xsquare)
+        elif isinstance(input, torch.cuda.DoubleTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Double_sum_square_Forward(
+                    input.view(B,C,-1), xsum, xsquare)
+        else:
+            raise RuntimeError('unimplemented')
+        return xsum, xsquare
+
+    def backward(ctx, gradSum, gradSquare):
+        input, = ctx.saved_tensors
+        B,C,H,W = input.size()
+        with torch.cuda.device_of(input):
+            gradInput = input.new().resize_(B,C,H*W).zero_()
+        # gradSum.view(1,C,1,1).expand_as(input) + \
+        #     2*gradSquare.view(1,C,1,1).expand_as(input)*input
+        if isinstance(input, torch.cuda.FloatTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Float_sum_square_Backward(
+                    gradInput, input.view(B,C,-1), gradSum, gradSquare)
+        elif isinstance(input, torch.cuda.DoubleTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Double_sum_square_Backward(
+                    gradInput, input.view(B,C,-1), gradSum, gradSquare)
+        else:
+            raise RuntimeError('unimplemented')
+        return gradInput.view(B,C,H,W)
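
Note: the two outputs of sum_square are per-channel reductions over the B*H*W elements, which is consistent with the commented-out gradient above (d(sum)/dx = 1, d(sum of squares)/dx = 2x). A plain-PyTorch sketch of the same semantics, independent of the CUDA backend in this commit:

import torch

def sum_square_reference(x):
    # x: (B, C, H, W) -> per-channel sum and sum of squares.
    B, C, H, W = x.size()
    flat = x.transpose(0, 1).contiguous().view(C, -1)   # C x (B*H*W)
    return flat.sum(dim=1), (flat * flat).sum(dim=1)

x = torch.randn(2, 3, 4, 4).double()
xsum, xsquare = sum_square_reference(x)
n = x.numel() // x.size(1)
mean = xsum / n
var = xsquare / n - mean * mean   # E[x^2] - E[x]^2, as a SyncBN layer would use them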

+class batchnormtrain(Function):
+    def forward(ctx, input, gamma, beta, mean, std):
+        ctx.save_for_backward(input, gamma, beta, mean, std)
+        assert(input.dim()==3)
+        with torch.cuda.device_of(input):
+            invstd = 1.0 / std
+            output = input.new().resize_as_(input)
+        if isinstance(input, torch.cuda.FloatTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Float_batchnorm_Forward(output,
+                    input, mean, invstd, gamma, beta)
+        elif isinstance(input, torch.cuda.DoubleTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Double_batchnorm_Forward(output,
+                    input, mean, invstd, gamma, beta)
+        else:
+            raise RuntimeError('unimplemented')
+        return output
+
+    def backward(ctx, gradOutput):
+        input, gamma, beta, mean, std = ctx.saved_tensors
+        invstd = 1.0 / std
+        with torch.cuda.device_of(input):
+            gradInput = gradOutput.new().resize_as_(input).zero_()
+            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
+            gradBeta = gradOutput.new().resize_as_(beta).zero_()
+            gradMean = gradOutput.new().resize_as_(mean).zero_()
+            gradStd = gradOutput.new().resize_as_(std).zero_()
+
+        if isinstance(input, torch.cuda.FloatTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Float_batchnorm_Backward(
+                    gradOutput, input, gradInput, gradGamma, gradBeta,
+                    mean, invstd, gamma, beta, gradMean, gradStd,
+                    True)
+        elif isinstance(input, torch.cuda.DoubleTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Double_batchnorm_Backward(
+                    gradOutput, input, gradInput, gradGamma, gradBeta,
+                    mean, invstd, gamma, beta, gradMean, gradStd,
+                    True)
+        else:
+            raise RuntimeError('unimplemented')
+        return gradInput, gradGamma, gradBeta, gradMean, gradStd
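
Note: given the argument names (mean, invstd = 1/std, gamma, beta), the backend presumably applies the usual affine normalization to a 3-D (B, C, L) input with externally supplied per-channel statistics. A pure-PyTorch sketch of that computation, for intuition only rather than the code path used here:

import torch

def batchnorm_reference(x, gamma, beta, mean, std):
    # x: (B, C, L); gamma, beta, mean, std: per-channel vectors of length C.
    x_hat = (x - mean.view(1, -1, 1)) / std.view(1, -1, 1)
    return gamma.view(1, -1, 1) * x_hat + beta.view(1, -1, 1)

x = torch.randn(4, 3, 16)
flat = x.transpose(0, 1).contiguous().view(3, -1)
y = batchnorm_reference(x, torch.ones(3), torch.zeros(3),
                        flat.mean(dim=1), flat.std(dim=1))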

+class batchnormeval(Function):
+    def forward(ctx, input, gamma, beta, mean, std):
+        ctx.save_for_backward(input, gamma, beta, mean, std)
+        assert(input.dim()==3)
+        with torch.cuda.device_of(input):
+            invstd = 1.0 / std
+            output = input.new().resize_as_(input)
+        if isinstance(input, torch.cuda.FloatTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Float_batchnorm_Forward(output,
+                    input, mean, invstd, gamma, beta)
+        elif isinstance(input, torch.cuda.DoubleTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Double_batchnorm_Forward(output,
+                    input, mean, invstd, gamma, beta)
+        else:
+            raise RuntimeError('unimplemented')
+        return output
+
+    def backward(ctx, gradOutput):
+        input, gamma, beta, mean, std = ctx.saved_tensors
+        invstd = 1.0 / std
+        with torch.cuda.device_of(input):
+            gradInput = gradOutput.new().resize_as_(input).zero_()
+            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
+            gradBeta = gradOutput.new().resize_as_(beta).zero_()
+            gradMean = gradOutput.new().resize_as_(mean).zero_()
+            gradStd = gradOutput.new().resize_as_(std).zero_()
+        if isinstance(input, torch.cuda.FloatTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Float_batchnorm_Backward(
+                    gradOutput, input, gradInput, gradGamma, gradBeta,
+                    mean, invstd, gamma, beta, gradMean, gradStd,
+                    False)
+        elif isinstance(input, torch.cuda.DoubleTensor):
+            with torch.cuda.device_of(input):
+                encoding_lib.Encoding_Double_batchnorm_Backward(
+                    gradOutput, input, gradInput, gradGamma, gradBeta,
+                    mean, invstd, gamma, beta, gradMean, gradStd,
+                    False)
+        else:
+            raise RuntimeError('unimplemented')
+        return gradInput, gradGamma, gradBeta, gradMean, gradStd
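
Note: batchnormtrain and batchnormeval share the same forward; they differ only in the final flag (True vs False) passed to the backend backward. A hypothetical sketch of how a synchronized BatchNorm module might choose between them; the attribute names (gamma, beta, running_mean, running_std) are illustrative and not taken from this commit:

from encoding import batchnormtrain, batchnormeval

def syncbn_apply(module, x, batch_mean, batch_std):
    # x: (B, C, H, W); all statistics are per-channel vectors of length C.
    B, C, H, W = x.size()
    x3d = x.view(B, C, -1)
    if module.training:
        # batch_mean/batch_std would come from sum_square, reduced across GPUs.
        y = batchnormtrain()(x3d, module.gamma, module.beta, batch_mean, batch_std)
    else:
        y = batchnormeval()(x3d, module.gamma, module.beta,
                            module.running_mean, module.running_std)
    return y.view(B, C, H, W)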

Empty file added encoding/_ext/__init__.py
12 changes: 12 additions & 0 deletions encoding/_ext/encoding_lib/__init__.py
@@ -0,0 +1,12 @@

from torch.utils.ffi import _wrap_function
from ._encoding_lib import lib as _lib, ffi as _ffi

__all__ = []
def _import_symbols(locals):
    for symbol in dir(_lib):
        fn = getattr(_lib, symbol)
        locals[symbol] = _wrap_function(fn, _ffi)
        __all__.append(symbol)

_import_symbols(locals())
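
Note: once _import_symbols runs, every symbol exported by the compiled _encoding_lib extension is exposed as a plain Python callable on this module, which is what the dtype-dispatched calls in encoding/__init__.py rely on. A sketch of direct usage, assuming the extension is built and a CUDA device is available; shapes follow aggregate.forward above:

import torch
from encoding._ext import encoding_lib

B, N, K, D = 2, 10, 4, 8
A = torch.cuda.DoubleTensor(B, N, K).uniform_()
R = torch.cuda.DoubleTensor(B, N, K, D).uniform_()
E = A.new(B, K, D)
encoding_lib.Encoding_Double_aggregate_forward(E, A, R)  # fills E in place
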
18 changes: 18 additions & 0 deletions encoding/kernel/common.h
@@ -0,0 +1,18 @@
// The maximum number of threads in a block
const int WARP_SIZE = 32;
const int MAX_BLOCK_SIZE = 512;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}
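
Note: getNumThreads simply rounds the element count up to the next candidate block size in {32, 64, 128, 256, 512}, capping at MAX_BLOCK_SIZE. A quick Python mirror of the same rule, for illustration only:

def get_num_threads(n_elem):
    # Mirrors getNumThreads above: smallest candidate block size covering
    # n_elem, capped at MAX_BLOCK_SIZE = 512.
    for size in (32, 64, 128, 256, 512):
        if n_elem <= size:
            return size
    return 512

assert [get_num_threads(n) for n in (1, 33, 200, 4096)] == [32, 64, 256, 512]
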
8 changes: 4 additions & 4 deletions encoding/kernel/generic/device_tensor.h
@@ -12,13 +12,13 @@
 #define THC_GENERIC_FILE "generic/device_tensor.h"
 #else
 template <int Dim>
-THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
+THCDeviceTensor<real, Dim> devicetensor(THCState *state, THCTensor *t) {
   if (!t) {
-    return THCDeviceTensor<float, Dim>();
+    return THCDeviceTensor<real, Dim>();
   }
   int inDim = THCTensor_(nDimension)(state, t);
   if (inDim == Dim) {
-    return toDeviceTensor<float, Dim>(state, t);
+    return toDeviceTensor<real, Dim>(state, t);
   }
   // View in which the last dimensions are collapsed or expanded as needed
   THAssert(THCTensor_(isContiguous)(state, t));
@@ -32,6 +32,6 @@ THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
       size[Dim - 1] *= t->size[i];
     }
   }
-  return THCDeviceTensor<float, Dim>(THCTensor_(data)(state, t), size);
+  return THCDeviceTensor<real, Dim>(THCTensor_(data)(state, t), size);
 }
 #endif
(The remaining changed files in this commit are not shown.)
