
Commit 180f4eb: init

twtygqyy committed Apr 19, 2017 (0 parents)
Showing 4 changed files with 267 additions and 0 deletions.
41 changes: 41 additions & 0 deletions README.md
@@ -0,0 +1,41 @@
# PyTorch SRResNet
Implementation of the paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" (https://arxiv.org/abs/1609.04802) in PyTorch.

## Usage
### Training
```
usage: main.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
[--step STEP] [--cuda] [--resume RESUME]
[--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS]
[--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
[--pretrained PRETRAINED]
optional arguments:
-h, --help show this help message and exit
--batchSize BATCHSIZE
training batch size
--nEpochs NEPOCHS number of epochs to train for
--lr LR Learning Rate. Default=1e-4
  --step STEP           Sets the learning rate to the initial LR decayed by a
                        factor of 10 every n epochs, Default: n=200
--cuda Use cuda?
--resume RESUME Path to checkpoint (default: none)
--start-epoch START_EPOCH
Manual epoch number (useful on restarts)
--clip CLIP Clipping Gradients. Default=0.1
--threads THREADS Number of threads for data loader to use, Default: 1
--momentum MOMENTUM Momentum, Default: 0.9
--weight-decay WEIGHT_DECAY, --wd WEIGHT_DECAY
weight decay, Default: 0
--pretrained PRETRAINED
path to pretrained model (default: none)
```
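
For example, a typical single-GPU training run might look like this (the flag values are illustrative, and main.py currently hardcodes the training file rgb_24in_96out_coco_scale4.h5):

```
python main.py --cuda --batchSize 16 --nEpochs 500 --lr 1e-4 --step 200 --threads 4
```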

### Todo

- Code for testing
- Code for data generation
- Performance evaluation
16 changes: 16 additions & 0 deletions dataset.py
@@ -0,0 +1,16 @@
import torch.utils.data as data
import torch
import h5py

class DatasetFromHdf5(data.Dataset):
    """Dataset backed by an HDF5 file holding "data" (inputs) and "label" (targets)."""
    def __init__(self, file_path):
        super(DatasetFromHdf5, self).__init__()
        hf = h5py.File(file_path, "r")  # open read-only
        self.data = hf.get("data")
        self.target = hf.get("label")

def __getitem__(self, index):
return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()

def __len__(self):
return self.data.shape[0]
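
Data generation code is still listed under Todo; as a stopgap, a minimal sketch of writing an HDF5 file in the layout DatasetFromHdf5 expects might look like this. The helper name write_training_h5 and the (N, 3, 24, 24) / (N, 3, 96, 96) patch shapes are assumptions inferred from the filename rgb_24in_96out_coco_scale4.h5 hardcoded in main.py:

```
import h5py
import numpy as np

def write_training_h5(lr_patches, hr_patches, path):
    """Hypothetical helper: write low-res inputs and high-res targets
    under the "data" and "label" keys that DatasetFromHdf5 reads."""
    with h5py.File(path, "w") as hf:
        hf.create_dataset("data", data=lr_patches.astype(np.float32))
        hf.create_dataset("label", data=hr_patches.astype(np.float32))

# Illustrative shapes for 4x super-resolution on 24x24 RGB inputs:
# lr = np.random.rand(100, 3, 24, 24); hr = np.random.rand(100, 3, 96, 96)
# write_training_h5(lr, hr, "rgb_24in_96out_coco_scale4.h5")
```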
145 changes: 145 additions & 0 deletions main.py
@@ -0,0 +1,145 @@
import argparse, os
import torch
import math, random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from srresnet import Net
from dataset import DatasetFromHdf5

# Training settings
parser = argparse.ArgumentParser(description="PyTorch SRResNet")
parser.add_argument("--batchSize", type=int, default=16, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=500, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--step", type=int, default=200, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=200")
parser.add_argument("--cuda", action="store_true", help="Use cuda?")
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--clip", type=float, default=0.1, help="Clipping Gradients. Default=0.1")
parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
parser.add_argument("--weight-decay", "--wd", default=0, type=float, help="weight decay, Default: 0")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")

def main():

global opt, model
opt = parser.parse_args()
print(opt)

cuda = opt.cuda
if cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")

opt.seed = random.randint(1, 10000)
print("Random Seed: ", opt.seed)
torch.manual_seed(opt.seed)
if cuda:
torch.cuda.manual_seed(opt.seed)

cudnn.benchmark = True

print("===> Loading datasets")
train_set = DatasetFromHdf5("rgb_24in_96out_coco_scale4.h5")
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

print("===> Building model")
model = Net()
criterion = nn.MSELoss(size_average=False)

print("===> Setting GPU")
if cuda:
model = model.cuda()
criterion = criterion.cuda()

# optionally resume from a checkpoint
if opt.resume:
if os.path.isfile(opt.resume):
print("=> loading checkpoint '{}'".format(opt.resume))
checkpoint = torch.load(opt.resume)
opt.start_epoch = checkpoint["epoch"] + 1
model.load_state_dict(checkpoint["model"].state_dict())
else:
print("=> no checkpoint found at '{}'".format(opt.resume))

# optionally copy weights from a checkpoint
if opt.pretrained:
if os.path.isfile(opt.pretrained):
print("=> loading model '{}'".format(opt.pretrained))
weights = torch.load(opt.pretrained)
model.load_state_dict(weights['model'].state_dict())
else:
print("=> no model found at '{}'".format(opt.pretrained))

print("===> Setting Optimizer")
#optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, betas = (0.9, 0.999), eps=1e-08)
optimizer = optim.Adam(model.parameters(), lr=opt.lr)
#optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)

print("===> Training")
for epoch in range(opt.start_epoch, opt.nEpochs + 1):
train(training_data_loader, optimizer, model, criterion, epoch)
save_checkpoint(model, epoch)

def total_gradient(parameters):
    """Returns the total L2 norm of the gradients, used for monitoring during training."""
parameters = list(filter(lambda p: p.grad is not None, parameters))
totalnorm = 0
for p in parameters:
modulenorm = p.grad.data.norm()
totalnorm += modulenorm ** 2
totalnorm = totalnorm ** (1./2)
return totalnorm

def adjust_learning_rate(optimizer, epoch):
    """Returns the initial LR decayed by a factor of 10 every opt.step epochs.
    E.g. with --lr 1e-4 and --step 200, epochs 1-200 use 1e-4 and epochs
    201-400 use 1e-5. The caller writes the value into the optimizer."""
    lr = opt.lr * (0.1 ** (epoch // opt.step))
    return lr

def train(training_data_loader, optimizer, model, criterion, epoch):

lr = adjust_learning_rate(optimizer, epoch-1)

for param_group in optimizer.param_groups:
param_group["lr"] = lr

print "epoch =", epoch,"lr =",optimizer.param_groups[0]["lr"]
model.train()

for iteration, batch in enumerate(training_data_loader, 1):

input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)

if opt.cuda:
input = input.cuda()
target = target.cuda()

loss = criterion(model(input), target)

optimizer.zero_grad()

loss.backward()

#nn.utils.clip_grad_norm(model.parameters(),opt.clip)

optimizer.step()

if iteration%100 == 0:
print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))
print "total gradient", total_gradient(model.parameters())

def save_checkpoint(model, epoch):
model_out_path = "model/" + "model_epoch_{}.pth".format(epoch)
state = {"epoch": epoch ,"model": model}
if not os.path.exists("model/"):
os.makedirs("model/")

torch.save(state, model_out_path)

print("Checkpoint saved to {}".format(model_out_path))

if __name__ == "__main__":
main()
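
Testing code is still listed under Todo; given the checkpoint format written by save_checkpoint() above, a minimal inference sketch (hypothetical, written in the same PyTorch-0.x style as this commit) could look like:

```
import torch
from torch.autograd import Variable

# save_checkpoint() pickles the full model object, so srresnet must be importable.
checkpoint = torch.load("model/model_epoch_500.pth")
model = checkpoint["model"]
model.eval()  # switch BatchNorm to inference statistics

# lr_image: a FloatTensor shaped (1, 3, H, W) in the training value range.
lr_image = torch.randn(1, 3, 24, 24)  # placeholder input
sr_image = model(Variable(lr_image, volatile=True)).data  # (1, 3, 4*H, 4*W)
```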
65 changes: 65 additions & 0 deletions srresnet.py
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import math

class _Residual_Block(nn.Module):
def __init__(self):
super(_Residual_Block, self).__init__()

self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(64)

def forward(self, x):
identity_data = x
output = self.relu(self.bn1(self.conv1(x)))
output = self.bn2(self.conv2(output))
output = torch.add(output,identity_data)
return output

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()

self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)

self.residual = self.make_layer(_Residual_Block, 15)

self.upscale4x = nn.Sequential(
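            # Each stage: the conv expands 64 -> 256 channels, then PixelShuffle(2)
            # trades those channels for a 2x larger feature map (back to 64 channels);
            # two stages give the overall 4x upscaling.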
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.ReLU(inplace=True),
)

self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1)

        # He (Kaiming) initialization for conv weights; BatchNorm layers start as identity.
        for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def make_layer(self, block, num_of_layer):
layers = []
for _ in range(num_of_layer):
layers.append(block())
return nn.Sequential(*layers)

def forward(self, x):
out = self.relu(self.conv_input(x))
residual = out
out = self.residual(out)
out = torch.add(out,residual)
out = self.upscale4x(out)
out = self.conv_output(out)
return out
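
As a quick, illustrative sanity check of the 4x factor (not part of this commit), a 24x24 RGB input should come out 96x96:

```
import torch
from torch.autograd import Variable
from srresnet import Net

net = Net()
x = Variable(torch.randn(1, 3, 24, 24))
y = net(x)
print(y.size())  # expected: (1, 3, 96, 96)
```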
