aggregate
zhanghang1989 committed May 15, 2017
1 parent c05c2a5 commit a3c3d94
Showing 8 changed files with 99 additions and 38 deletions.
12 changes: 8 additions & 4 deletions encoding/__init__.py
@@ -16,16 +16,20 @@
class aggregate(Function):
    def forward(self, A, R):
        # A \in(BxNxK) R \in(BxNxKxD) => E \in(BxKxD)
        self.save_for_backward(A, R)
        B, N, K, D = R.size()
        E = A.new(B,K,D)
        # TODO support cpu backend
        print(encoding_lib)
        encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
        return E

    def backward(self, E):
        # TODO FIXME this is test only
        return E
    def backward(self, gradE):
        A, R = self.saved_tensors
        gradA = A.clone()
        gradR = R.clone()
        encoding_lib.Encoding_Float_aggregate_backward(gradA, gradR, gradE,
                                                       A, R)
        return gradA, gradR


class Aggregate(Module):
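For reference, the operation these kernels implement is E[b][k][d] = sum_i A[b][i][k] * R[b][i][k][d]. The following is a minimal pure-PyTorch sketch of the same forward and backward math (not part of this commit; the names aggregate_forward_ref / aggregate_backward_ref are illustrative, and it assumes current broadcasting semantics). It can serve as a CPU fallback or as a cross-check of the CUDA backend:

import torch

def aggregate_forward_ref(A, R):
    # A: (B, N, K) assignment weights, R: (B, N, K, D) residuals
    # E[b, k, d] = sum_i A[b, i, k] * R[b, i, k, d]
    return (A.unsqueeze(3) * R).sum(dim=1)        # -> (B, K, D)

def aggregate_backward_ref(gradE, A, R):
    # gradE = dl/dE with shape (B, K, D)
    gradA = (gradE.unsqueeze(1) * R).sum(dim=3)   # dl/dA -> (B, N, K)
    gradR = gradE.unsqueeze(1) * A.unsqueeze(3)   # dl/dR -> (B, N, K, D)
    return gradA, gradR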
37 changes: 37 additions & 0 deletions encoding/kernel/generic/device_tensor.h
@@ -0,0 +1,37 @@
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
* Created by: Hang Zhang
* ECE Department, Rutgers University
* Email: zhang.hang@rutgers.edu
* Copyright (c) 2017
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree
*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
*/
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/device_tensor.h"
#else
template <int Dim>
THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
  if (!t) {
    return THCDeviceTensor<float, Dim>();
  }
  int inDim = THCTensor_(nDimension)(state, t);
  if (inDim == Dim) {
    return toDeviceTensor<float, Dim>(state, t);
  }
  // View in which the last dimensions are collapsed or expanded as needed
  THAssert(THCTensor_(isContiguous)(state, t));
  int size[Dim];
  for (int i = 0; i < Dim || i < inDim; ++i) {
    if (i < Dim && i < inDim) {
      size[i] = t->size[i];
    } else if (i < Dim) {
      size[i] = 1;
    } else {
      size[Dim - 1] *= t->size[i];
    }
  }
  return THCDeviceTensor<float, Dim>(THCTensor_(data)(state, t), size);
}
#endif
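The devicetensor<Dim> helper above wraps a contiguous THCTensor as a fixed-rank THCDeviceTensor: if the tensor has fewer than Dim dimensions it pads with size-1 dimensions, and if it has more it folds the surplus trailing dimensions into the last one. A rough Python analogy, given only for intuition (to_rank is a hypothetical name, not part of the commit):

def to_rank(t, dim):
    # pad with trailing size-1 dims, or fold the surplus trailing dims into the last one
    sizes = list(t.size())
    if len(sizes) < dim:
        sizes += [1] * (dim - len(sizes))
    else:
        sizes = sizes[:dim - 1] + [-1]
    return t.contiguous().view(*sizes)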
47 changes: 29 additions & 18 deletions encoding/kernel/generic/encoding_kernel.c
@@ -17,7 +17,7 @@ __global__ void Encoding_(Aggregate_Forward_kernel) (
THCDeviceTensor<real, 3> A,
THCDeviceTensor<real, 4> R)
/*
* aggregating kernel function
* aggregating forward kernel function
*/
{
/* declarations of the variables */
@@ -41,7 +41,7 @@ __global__ void Encoding_(Aggregate_Forward_kernel) (
void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
THCTensor *A_, THCTensor *R_)
/*
* aggregating the residuals with assignment weights
* aggregating forward the residuals with assignment weights
*/
{
/* Check the GPU index and tensor dims*/
@@ -63,55 +63,66 @@ void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
THCudaCheck(cudaGetLastError());
}

/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
__global__ void Encoding_(Aggregate_Backward_kernel) (
THCDeviceTensor<real, 3> G,
THCDeviceTensor<real, 3> GA,
THCDeviceTensor<real, 4> GR,
THCDeviceTensor<real, 3> L,
THCDeviceTensor<real, 3> A,
THCDeviceTensor<real, 4> R)
/*
* aggregating backward kernel function
* GA (dl/dA), GR (dl/dR), L (dl/dE)
*/
{
/* declarations of the variables */
int b, k, d, i, D;
real sum;
/* Get the index and channels */
b = blockIdx.z;
k = blockIdx.x * blockDim.x + threadIdx.x;
i = blockIdx.y * blockDim.y + threadIdx.y;
k = blockIdx.x * blockDim.x + threadIdx.x;
D = L.getSize(2);
/* boundary check for output */
if (k >= G.getSize(2) || i >= G.getSize(1)) return;
/* boundary check for output G \in R^{BxNxKxD} */
if (k >= GR.getSize(2) || i >= GR.getSize(1)) return;
/* main operation */
sum = 0;
for(d=0; d<D; d++) {
//sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
GR[b][i][k][d] = L[b][k][d] * A[b][i][k];
sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
}
G[b][i][k] = sum;
GA[b][i][k] = sum;
}
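
The backward kernel follows from the chain rule applied to the aggregation E_{bkd} = \sum_i A_{bik} R_{bikd}; writing L = \partial\ell/\partial E, it computes

\frac{\partial \ell}{\partial A_{bik}} = \sum_{d} L_{bkd}\, R_{bikd} \quad\text{(stored in GA)},
\qquad
\frac{\partial \ell}{\partial R_{bikd}} = L_{bkd}\, A_{bik} \quad\text{(stored in GR)}.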

void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *G_,
THCTensor *L_, THCTensor *R_)
void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_,
THCTensor *GR_, THCTensor *L_, THCTensor *A_, THCTensor *R_)
/*
* aggregate backward to assignment weights and residuals
* GA (dl/dA), GR (dl/dR), L (dl/dE)
*/
{
/* Check the GPU index and tensor dims*/
THCTensor_(checkGPU)(state, 3, G_, L_, R_);
if (THCTensor_(nDimension)(state, G_) != 3 ||
THCTensor_(nDimension)(state, L_) != 3 ||
THCTensor_(nDimension)(state, R_) != 4)
THCTensor_(checkGPU)(state, 5, GA_, GR_, L_, A_, R_);
if (THCTensor_(nDimension)(state, GA_) != 3 ||
THCTensor_(nDimension)(state, GR_) != 4 ||
THCTensor_(nDimension)(state, L_) != 3 ||
THCTensor_(nDimension)(state, A_) != 3 ||
THCTensor_(nDimension)(state, R_) != 4)
THError("Encoding: incorrect input dims. \n");
/* Device tensors */
THCDeviceTensor<real, 3> G = devicetensor<3>(state, G_);
THCDeviceTensor<real, 3> GA = devicetensor<3>(state, GA_);
THCDeviceTensor<real, 4> GR = devicetensor<4>(state, GR_);
THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_);
THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
/* kernel function */
cudaStream_t stream = THCState_getCurrentStream(state);
dim3 threads(16, 16);
dim3 blocks(G.getSize(2)/16+1, G.getSize(1)/16+1,
G.getSize(0));
Encoding_(Aggregate_Backward_kernel)<<<blocks, threads, 0, stream>>>(G, L, R);
dim3 blocks(GA.getSize(2)/16+1, GA.getSize(1)/16+1,
GA.getSize(0));
Encoding_(Aggregate_Backward_kernel)<<<blocks, threads, 0, stream>>>(GA,
GR, L, A, R);
THCudaCheck(cudaGetLastError());
}

#endif
4 changes: 2 additions & 2 deletions encoding/kernel/generic/encoding_kernel.h
@@ -14,6 +14,6 @@

void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
THCTensor *A_, THCTensor *R_);
void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *G_,
THCTensor *L_, THCTensor *R_);
void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_,
THCTensor *GR_, THCTensor *L_, THCTensor *A_, THCTensor *R_);
#endif
7 changes: 7 additions & 0 deletions encoding/make.sh
@@ -0,0 +1,7 @@
#!/usr/bin/env bash

mkdir -p encoding/build && cd encoding/build
# compile and install
cmake ..
make install
cd ..
4 changes: 2 additions & 2 deletions encoding/src/encoding_lib.h
@@ -22,5 +22,5 @@

int Encoding_Float_aggregate_forward(THCudaTensor *E, THCudaTensor *A,
THCudaTensor *R);
int Encoding_Float_aggregate_backward(THCudaTensor *G, THCudaTensor *L,
THCudaTensor *R);
int Encoding_Float_aggregate_backward(THCudaTensor *GA, THCudaTensor *GR,
THCudaTensor *L, THCudaTensor *A, THCudaTensor *R);
10 changes: 5 additions & 5 deletions encoding/src/generic/encoding_generic.c
@@ -23,15 +23,15 @@ int Encoding_(aggregate_forward)(THCudaTensor *E, THCudaTensor *A,
return 0;
}

int Encoding_(aggregate_backward)(THCudaTensor *E, THCudaTensor *A,
THCudaTensor *R)
int Encoding_(aggregate_backward)(THCudaTensor *GA, THCudaTensor *GR,
THCudaTensor *L, THCudaTensor *A, THCudaTensor *R)
/*
* Aggregate operation
* Aggregate backward operation to A
* GA (dl/dA), GR (dl/dR), L (dl/dE), A (assignments)
*/
{
Encoding_(Aggregate_Backward)(state, E, A, R);
Encoding_(Aggregate_Backward)(state, GA, GR, L, A, R);
/* C function return number of the outputs */
return 0;
}

#endif
16 changes: 9 additions & 7 deletions test/test.py
@@ -12,13 +12,15 @@
import torch.nn as nn
from torch.autograd import Variable
from encoding import Aggregate
from torch.autograd import gradcheck

model = Aggregate()

# declare dims and variables
B, N, K, D = 1, 2, 3, 4
# TODO cpu test
A = Variable(torch.ones(B,N,K).cuda())
R = Variable(torch.ones(B,N,K,D).cuda())
A = Variable(torch.randn(B,N,K).cuda(), requires_grad=True)
R = Variable(torch.randn(B,N,K,D).cuda(), requires_grad=True)

# check Aggregate operation
test = gradcheck(Aggregate(),(A, R), eps=1e-4, atol=1e-3)
print('Gradcheck of Aggregate() returns ', test)


E = model(A, R)
print(E)
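
A further sanity check one could add here (hypothetical, not part of this commit; assumes current broadcasting semantics) compares the kernel output against the elementwise definition of the aggregation:

E_ref = (A.unsqueeze(3) * R).sum(dim=1)   # E[b,k,d] = sum_i A[b,i,k] * R[b,i,k,d]
print('max abs diff vs CUDA kernel:', (E - E_ref).abs().max())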
