Commit

cifar exp
zhanghang1989 committed May 21, 2017
1 parent de33018 commit 1633f31
Showing 14 changed files with 588 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,4 +2,4 @@
 *.swp
 *.pyc
-build/
+encoding/build/
 data/
20 changes: 15 additions & 5 deletions encoding/__init__.py
@@ -25,8 +25,8 @@ def forward(self, A, R):

     def backward(self, gradE):
         A, R = self.saved_tensors
-        gradA = A.clone()
-        gradR = R.clone()
+        gradA = A.new().resize_as_(A)
+        gradR = R.new().resize_as_(R)
         encoding_lib.Encoding_Float_aggregate_backward(gradA, gradR, gradE,
             A, R)
         return gradA, gradR
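A note on this change: aggregate computes E[b,k,:] = sum_i A[b,i,k] * R[b,i,k,:], and its backward kernel (shown further down in encoding_kernel.c) overwrites every element of gradA and gradR. The buffers therefore only need the right type and shape, which A.new().resize_as_(A) provides without the extra copy of A's contents that A.clone() performed and then immediately discarded. A minimal pure-PyTorch sketch of the same forward/backward math, using hypothetical sizes and current PyTorch reduction semantics rather than the compiled extension:

    import torch

    B, N, K, D = 2, 9, 4, 8                  # hypothetical batch/locations/codewords/channels
    A = torch.randn(B, N, K)                 # soft assignments
    R = torch.randn(B, N, K, D)              # residuals
    L = torch.randn(B, K, D)                 # upstream gradient dl/dE

    E = (A.unsqueeze(3) * R).sum(1)          # forward:  B x K x D
    gradA = (L.unsqueeze(1) * R).sum(3)      # the kernel's GA: B x N x K
    gradR = L.unsqueeze(1) * A.unsqueeze(3)  # the kernel's GR: B x N x K x D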
@@ -36,6 +36,7 @@ class Aggregate(nn.Module):
     def forward(self, A, R):
         return aggregate()(A, R)
 
+
 class Encoding(nn.Module):
     def __init__(self, D, K):
         super(Encoding, self).__init__()
@@ -47,13 +48,19 @@ def __init__(self, D, K):
         self.reset_params()
 
     def reset_params(self):
-        self.codewords.data.uniform_(0.0, 0.02)
-        self.scale.data.uniform_(0.0, 0.02)
+        std1 = 1./((self.K*self.D)**0.5)  # 0.5, not (1/2): (1/2) is 0 under Python 2 integer division
+        std2 = 1./((self.K)**0.5)
+        self.codewords.data.uniform_(-std1, std1)
+        self.scale.data.uniform_(-std2, std2)
 
     def forward(self, X):
         # input X is a 4D tensor
         assert(X.dim()==4, "Encoding Layer requires 4D featuremaps!")
         assert(X.size(1)==self.D,"Encoding Layer incompatible input channels!")
+        unpacked = False
+        if X.dim() == 3:
+            unpacked = True
+            X = X.unsqueeze(0)
 
         B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D
         # reshape input
         X = X.view(B,D,-1).transpose(1,2)
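The new reset_params draws the codewords from U(-std1, std1) with std1 = 1/sqrt(K*D) and the smoothing factors from U(-std2, std2) with std2 = 1/sqrt(K), replacing the fixed (0.0, 0.02) range; this is the usual fan-in style scaling that keeps the initial output magnitude roughly independent of the layer size. A standalone sketch of the same initialization with hypothetical sizes:

    import torch

    K, D = 16, 128                        # hypothetical codewords / channels
    std1 = 1. / (K * D) ** 0.5
    std2 = 1. / K ** 0.5
    codewords = torch.Tensor(K, D).uniform_(-std1, std1)
    scale = torch.Tensor(K).uniform_(-std2, std2)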
@@ -67,6 +74,9 @@ def forward(self, X):
         A = self.softmax(A.view(B*N,K)).view(B,N,K)
         # aggregate
         E = aggregate()(A, R)
+
+        if unpacked:
+            E = E.squeeze(0)
         return E

def __repr__(self):
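With the forward change above, the layer accepts a single unbatched 3D feature map as well as a 4D batch: 3D input is unsqueezed to a batch of one, and the batch dimension is squeezed back out of E on return. A hedged usage sketch (shapes are hypothetical, E is B x K x D per the aggregate operation, the Variable wrapper matches the PyTorch API of this era, and the compiled encoding extension must be available):

    import torch
    from torch.autograd import Variable
    from encoding import Encoding

    layer = Encoding(D=128, K=16)
    x4 = Variable(torch.randn(2, 128, 8, 8))  # B x D x H x W -> E: 2 x 16 x 128
    x3 = Variable(torch.randn(128, 8, 8))     # D x H x W -> E: 16 x 128
    print(layer(x4).size(), layer(x3).size())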
2 changes: 1 addition & 1 deletion encoding/kernel/generic/encoding_kernel.c
@@ -89,7 +89,7 @@ __global__ void Encoding_(Aggregate_Backward_kernel) (
     sum = 0;
     for(d=0; d<D; d++) {
         //sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
-        GR[b][i][k][d] = L[b][k][d] * A[b][i][k];
+        GR[b][i][k][d] = L[b][k][d].ldg() * A[b][i][k].ldg();
         sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
     }
     GA[b][i][k] = sum;
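For reference, the .ldg() accessor here presumably wraps CUDA's __ldg() intrinsic, which loads through the read-only data cache (compute capability 3.5+); with this fix the GR write uses the same cached loads as the sum accumulation on the following line.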
8 changes: 4 additions & 4 deletions encoding/src/generic/encoding_generic.c
Original file line number Diff line number Diff line change
@@ -12,8 +12,8 @@
#define THC_GENERIC_FILE "generic/encoding_generic.c"
#else

-int Encoding_(aggregate_forward)(THCudaTensor *E, THCudaTensor *A,
-    THCudaTensor *R)
+int Encoding_(aggregate_forward)(THCTensor *E, THCTensor *A,
+    THCTensor *R)
/*
* Aggregate operation
*/
@@ -23,8 +23,8 @@ int Encoding_(aggregate_forward)(THCudaTensor *E, THCudaTensor *A,
return 0;
}

-int Encoding_(aggregate_backward)(THCudaTensor *GA, THCudaTensor *GR,
-    THCudaTensor *L, THCudaTensor *A, THCudaTensor *R)
+int Encoding_(aggregate_backward)(THCTensor *GA, THCTensor *GR,
+    THCTensor *L, THCTensor *A, THCTensor *R)
/*
* Aggregate backward operation to A
* G (dl/dR), L (dl/dE), A (assignments)
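For context, the THCudaTensor-to-THCTensor rename matches the THC generic pattern set up by the THC_GENERIC_FILE define above: the file is compiled once per scalar type, and Encoding_(aggregate_forward) expands to a per-type symbol such as Encoding_Float_aggregate_forward, which is what the Python wrapper in encoding/__init__.py calls.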
Empty file added experiments/dataset/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions experiments/dataset/cifar.py
@@ -0,0 +1,44 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
import torchvision
import torchvision.transforms as transforms

class Dataloder():
def __init__(self, args):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)

kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
trainloader = torch.utils.data.DataLoader(trainset, batch_size=
args.batch_size, shuffle=True, **kwargs)
testloader = torch.utils.data.DataLoader(testset, batch_size=
args.batch_size, shuffle=False, **kwargs)
self.trainloader = trainloader
self.testloader = testloader

def getloader(self):
return self.trainloader, self.testloader
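A short usage sketch for the loader above (the args values are hypothetical stand-ins; the real ones come from the Options parser in option.py, whose diff is not shown in this view):

    import argparse
    from dataset.cifar import Dataloder

    args = argparse.Namespace(batch_size=128, cuda=False)  # hypothetical args
    trainloader, testloader = Dataloder(args).getloader()
    images, labels = next(iter(trainloader))
    print(images.size())  # torch.Size([128, 3, 32, 32])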
131 changes: 131 additions & 0 deletions experiments/main.py
@@ -0,0 +1,131 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from __future__ import print_function

import os  # used by the resume logic below (os.path.isfile)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from option import Options
from model.encodenet import Net
from utils import *

# global variable
best_pred = 0.0
acclist = []

def main():
    # let resumed best_pred/acclist update the module-level records used by train/test
    global best_pred, acclist
    # init the args
args = Options().parse()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
# init dataloader
if args.dataset == 'cifar':
from dataset.cifar import Dataloder
train_loader, test_loader = Dataloder(args).getloader()
else:
        raise ValueError('Unknown dataset!')

model = Net()

if args.cuda:
model.cuda()

if args.resume is not None:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_pred = checkpoint['best_pred']
acclist = checkpoint['acclist']
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no resume checkpoint found at '{}'".format(args.resume))

criterion = nn.CrossEntropyLoss()
    # TODO make weight_decay one of args
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=
args.momentum, weight_decay=1e-4)

def train(epoch):
model.train()
global best_pred
train_loss, correct, total = 0,0,0
adjust_learning_rate(optimizer, epoch, best_pred, args)
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

train_loss += loss.data[0]
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum()
total += target.size(0)
progress_bar(batch_idx, len(train_loader),
'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1),
100.*correct/total, correct, total))

def test(epoch):
model.eval()
global best_pred
global acclist
test_loss, correct, total = 0,0,0
acc = 0.0
is_best = False
# for data, target in test_loader:
for batch_idx, (data, target) in enumerate(test_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += criterion(output, target).data[0]
# get the index of the max log-probability
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum()
total += target.size(0)

acc = 100.*correct/total
progress_bar(batch_idx, len(test_loader),
'Loss: %.3f | Acc: %.3f%% (%d/%d)'% (test_loss/(batch_idx+1),
acc, correct, total))
# save checkpoint
acclist += [acc]
if acc > best_pred:
best_pred = acc
is_best = True
save_checkpoint({
'epoch': epoch,
'state_dict': model.state_dict(),
'best_pred': best_pred,
'acclist':acclist,
}, args=args, is_best=is_best)

# TODO add plot curve

for epoch in range(args.start_epoch, args.epochs + 1):
train(epoch)
        # FIXME there is a bug somewhere, but not in this code
test(epoch)


if __name__ == "__main__":
main()
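main.py is written against the early, pre-0.4 PyTorch API: Variable wrappers, volatile=True for inference, and loss.data[0] for scalar extraction. On current PyTorch the equivalent idioms are torch.no_grad() and loss.item(); a minimal self-contained sketch with a stand-in model:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(4, 2)                  # stand-in model for illustration
    data = torch.randn(3, 4)
    target = torch.tensor([0, 1, 0])
    model.eval()
    with torch.no_grad():                    # replaces Variable(data, volatile=True)
        output = model(data)
        loss = F.cross_entropy(output, target)
    print(loss.item())                       # replaces loss.data[0]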
Empty file added experiments/model/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions experiments/model/encodenet.py
@@ -0,0 +1,58 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
import torch.nn as nn
import model.mynn as nn2
from encoding import Encoding

class Net(nn.Module):
def __init__(self, num_blocks=[2,2,2,2], num_classes=10,
block=nn2.Bottleneck):
super(Net, self).__init__()
if block == nn2.Basicblock:
self.expansion = 1
else:
self.expansion = 4

self.inplanes = 64
num_planes = [64, 128, 256, 512]
strides = [1, 2, 2, 2]
model = []
# Conv_1
model += [nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1),
nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True)]
# Residual units
for i in range(4):
model += [self._residual_unit(block, num_planes[i], num_blocks[i],
strides[i])]
# Last conv layer
# TODO norm layer, instance norm?
model += [nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True),
Encoding(D=512*self.expansion,K=16),
nn.BatchNorm1d(16),
nn.ReLU(inplace=True),
nn2.View(-1, 512*self.expansion*16),
nn.Linear(512*self.expansion*16, num_classes)]
self.model = nn.Sequential(*model)
print(model)

def _residual_unit(self, block, planes, n_blocks, stride):
strides = [stride] + [1]*(n_blocks-1)
layers = []
for i in range(n_blocks):
layers += [block(self.inplanes, planes, strides[i])]
self.inplanes = self.expansion*planes
return nn.Sequential(*layers)

def forward(self, input):
return self.model(input)
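A quick shape check for the network above (a sketch only: it assumes the compiled encoding extension, the Bottleneck/Basicblock/View helpers from model/mynn.py whose diff is not loaded in this view, and a CUDA device, since the aggregate kernels are CUDA-only):

    import torch
    from torch.autograd import Variable
    from model.encodenet import Net

    net = Net(num_blocks=[2, 2, 2, 2], num_classes=10).cuda()
    x = Variable(torch.randn(2, 3, 32, 32).cuda())  # two CIFAR-sized images
    print(net(x).size())                            # expected: torch.Size([2, 10])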
(Diffs for the remaining changed files were not loaded in this view.)