Merge branch 'master' of https://github.com/NVlabs/SPADE
taesungp committed Apr 14, 2019
2 parents d687d04 + 8dfaea6 commit 44f15f1
Showing 15 changed files with 213 additions and 195 deletions.
3 changes: 0 additions & 3 deletions data/pix2pix_dataset.py
@@ -41,9 +41,6 @@ def initialize(self, opt):

size = len(self.label_paths)
self.dataset_size = size
# if opt.isTrain:
# round_to_ngpus = (size // ngpus) * ngpus
# self.dataset_size = round_to_ngpus

def get_paths(self, opt):
label_paths = []
55 changes: 55 additions & 0 deletions datasets/coco_generate_instance_map.py
@@ -0,0 +1,55 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import os
import argparse
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
from skimage.draw import polygon

parser = argparse.ArgumentParser()
parser.add_argument('--annotation_file', type=str, default="./annotations/instances_train2017.json",
help="Path to the annocation file. It can be downloaded at http://images.cocodataset.org/annotations/annotations_trainval2017.zip. Should be either instances_train2017.json or instances_val2017.json")
parser.add_argument('--input_label_dir', type=str, default="./train_label/",
help="Path to the directory containing label maps. It can be downloaded at http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip")
parser.add_argument('--output_instance_dir', type=str, default="./train_inst/",
help="Path to the output directory of instance maps")

opt = parser.parse_args()

print("annotation file at {}".format(opt.annotation_file))
print("input label maps at {}".format(opt.input_label_dir))
print("output dir at {}".format(opt.output_instance_dir))

# initialize COCO api for instance annotations
coco = COCO(opt.annotation_file)


# display COCO categories and supercategories
cats = coco.loadCats(coco.getCatIds())
imgIds = coco.getImgIds(catIds=coco.getCatIds(cats))
for ix, id in enumerate(imgIds):
if ix % 50 == 0:
print("{} / {}".format(ix, len(imgIds)))
img_dict = coco.loadImgs(id)[0]
filename = img_dict["file_name"].replace("jpg", "png")
label_name = os.path.join(opt.input_label_dir, filename)
inst_name = os.path.join(opt.output_instance_dir, filename)
img = io.imread(label_name, as_grey=True)

annIds = coco.getAnnIds(imgIds=id, catIds=[], iscrowd=None)
anns = coco.loadAnns(annIds)
count = 0
for ann in anns:
    # check the key exists before reading it, then keep only polygon-style segmentations
    if "segmentation" in ann and type(ann["segmentation"]) == list:
        for seg in ann["segmentation"]:
            poly = np.array(seg).reshape((int(len(seg) / 2), 2))
            rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1)
            img[rr, cc] = count
    count += 1  # one instance id per annotation

io.imsave(inst_name, img)
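
The per-annotation loop above amounts to rasterizing each polygon into the label map with skimage.draw.polygon and stamping in the running instance id. A minimal sketch on synthetic data follows; the polygon coordinates, array size, and the example command in the comment are illustrative only, not taken from the dataset.

# Minimal, self-contained sketch of the rasterization step above. The script
# itself would be invoked with the flags defined earlier, e.g. (paths illustrative):
#   python datasets/coco_generate_instance_map.py --annotation_file ./annotations/instances_val2017.json --input_label_dir ./val_label/ --output_instance_dir ./val_inst/
import numpy as np
from skimage.draw import polygon

seg = [10.0, 10.0, 60.0, 15.0, 35.0, 50.0]        # flat COCO-style [x0, y0, x1, y1, ...] list (made up)
img = np.zeros((64, 64), dtype=np.uint8)          # stand-in for the loaded label map
poly = np.array(seg).reshape((len(seg) // 2, 2))  # -> (N, 2) array of (x, y) vertices
rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1)  # rows from y, columns from x
img[rr, cc] = 1                                   # paint this instance's id into the map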
17 changes: 7 additions & 10 deletions models/networks/__init__.py
@@ -3,30 +3,26 @@
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import importlib
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.generator import *
from models.networks.encoder import *
import util.util as util


def find_network_using_name(target_network_name, filename):
target_class_name = target_network_name + filename
module_name = 'models.networks.' + filename
network = util.find_class_in_module(target_class_name, module_name)

assert issubclass(network, BaseNetwork), \
"Class %s should be a subclass of BaseNetwork" % network

return network


def modify_commandline_options(parser, is_train):
opt, _ = parser.parse_known_args()

@@ -37,18 +33,20 @@ def modify_commandline_options(parser, is_train):
parser = netD_cls.modify_commandline_options(parser, is_train)
netE_cls = find_network_using_name('conv', 'encoder')
parser = netE_cls.modify_commandline_options(parser, is_train)

return parser


def create_network(cls, opt):
net = cls(opt)
net.print_network()
if len(opt.gpu_ids) > 0:
assert(torch.cuda.is_available())
net.cuda()
net.init_weights(opt.init_type, opt.init_variance)
return net


def define_G(opt):
netG_cls = find_network_using_name(opt.netG, 'generator')
return create_network(netG_cls, opt)
@@ -63,4 +61,3 @@ def define_E(opt):
# there exists only one encoder type
netE_cls = find_network_using_name('conv', 'encoder')
return create_network(netE_cls, opt)
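
For orientation, the lookup convention behind find_network_using_name is the network name concatenated with the file name, resolved inside models/networks/<filename>.py. The sketch below illustrates this; the concrete class names it resolves are assumptions based on that convention rather than something spelled out in this diff.

# Hedged sketch of the lookup convention implied by find_network_using_name;
# the resolved class names (SPADEGenerator, MultiscaleDiscriminator, ConvEncoder)
# are assumptions based on <name> + <filename>, not read from this diff.
import models.networks as networks

netG_cls = networks.find_network_using_name('spade', 'generator')
netD_cls = networks.find_network_using_name('multiscale', 'discriminator')
netE_cls = networks.find_network_using_name('conv', 'encoder')
print(netG_cls.__name__, netD_cls.__name__, netE_cls.__name__)
# A full options object (built by the project's option parsers) would then be
# passed to create_network / define_G / define_D / define_E to instantiate them.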

37 changes: 13 additions & 24 deletions models/networks/architecture.py
@@ -3,35 +3,28 @@
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import math
import re
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from models.networks.base_network import BaseNetwork
from models.networks.normalization import SPADE


## ResNet block that uses SPADE.
## It differs from the ResNet block of pix2pixHD in that
## it takes in the segmentation map as input, learns the skip connection if necessary,
## and applies normalization first and then convolution.
## This architecture seemed like a standard architecture for unconditional or
## class-conditional GAN architecture using residual block.
## The code was inspired from https://github.com/LMescheder/GAN_stability.
# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This architecture seemed like a standard architecture for unconditional or
# class-conditional GAN architecture using residual block.
# The code was inspired from https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, opt):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)


# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
@@ -44,7 +37,7 @@ def __init__(self, fin, fout, opt):
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)

# define normalization layers
spade_config_str = opt.norm_G.replace('spectral', '')
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc)
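
Only the constructor appears in this hunk, but the comments above already describe the data path: normalization first, then activation and convolution, with a learned shortcut when fin differs from fout (and, per the line above, a norm_G string such as 'spectralspadebatch3x3', an assumed example, has its 'spectral' prefix stripped before reaching SPADE). The following is a hedged sketch of that forward pass; attribute names outside the shown lines (norm_1, norm_s, actvn) are assumptions about the rest of the class.

# Hedged sketch of the normalization-first residual path described in the
# comments above; it is not a verbatim copy of code hidden by this hunk.
def forward(self, x, seg):
    x_s = self.conv_s(self.norm_s(x, seg)) if self.learned_shortcut else x  # learned skip when fin != fout
    dx = self.conv_0(self.actvn(self.norm_0(x, seg)))  # SPADE norm -> activation -> 3x3 conv
    dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
    return x_s + dx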
@@ -89,15 +82,14 @@ def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
nn.ReflectionPad2d(pw),
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size))
)


def forward(self, x):
y = self.conv_block(x)
out = x + y
return out


## VGG architecture, used for the perceptual loss using a pretrained VGG network
# VGG architecture, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
def __init__(self, requires_grad=False):
super().__init__()
@@ -123,12 +115,9 @@ def __init__(self, requires_grad=False):

def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
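
The five relu slices returned above are what a perceptual loss consumes. A hedged sketch of such a loss follows; the per-layer 1/32 to 1 weighting is an assumed, commonly used schedule, not something stated in this diff.

# Hedged sketch of a VGG perceptual loss built on the five feature slices
# returned above; the weights and class name are assumptions for illustration.
import torch.nn as nn

class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = VGG19().eval()                 # frozen feature extractor defined above
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_feats, y_feats = self.vgg(x), self.vgg(y)
        return sum(w * self.criterion(xf, yf.detach())
                   for w, xf, yf in zip(self.weights, x_feats, y_feats))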



5 changes: 2 additions & 3 deletions models/networks/base_network.py
@@ -3,10 +3,10 @@
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
from torch.nn import init


class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
@@ -44,7 +44,7 @@ def init_func(m):
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
@@ -57,4 +57,3 @@ def init_func(m):
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)
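
As a brief usage sketch, initialization flows top-down: create_network in models/networks/__init__.py calls init_weights on the network, which applies init_func and then recurses into children, as shown above. The subclass, option object, and gain value below are illustrative assumptions, not part of this diff.

# Hedged usage sketch: a BaseNetwork subclass is initialized recursively
# through init_weights. The opt object is assumed to come from the project's
# option parsers and carries the fields the encoder needs.
from models.networks.encoder import ConvEncoder

net = ConvEncoder(opt)
net.print_network()                      # parameter summary provided by BaseNetwork
net.init_weights('kaiming', gain=0.02)   # 'orthogonal' and 'none' are handled above as well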

30 changes: 10 additions & 20 deletions models/networks/discriminator.py
@@ -3,14 +3,7 @@
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import sys
import torch
import re
import torch.nn as nn
from collections import OrderedDict
import os.path
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
@@ -31,7 +24,7 @@ def modify_commandline_options(parser, is_train):
subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator',
'models.networks.discriminator')
subnetD.modify_commandline_options(parser, is_train)

return parser

def __init__(self, opt):
@@ -55,8 +48,8 @@ def downsample(self, input):
stride=2, padding=[1, 1],
count_include_pad=False)

## Returns list of lists of discriminator outputs.
## The final result is of size opt.num_D x opt.n_layers_D
# Returns list of lists of discriminator outputs.
# The final result is of size opt.num_D x opt.n_layers_D
def forward(self, input):
result = []
get_intermediate_features = not self.opt.no_ganFeat_loss
@@ -69,7 +62,7 @@ def forward(self, input):

return result
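
The per-scale loop elided between these hunks is what the comment above refers to: each sub-discriminator processes the input, and the input is then downsampled for the next, coarser discriminator, yielding opt.num_D output lists. The sketch below is a hedged illustration of that idea, not a verbatim copy of the hidden lines; the child-module iteration is an assumption.

# Hedged sketch of the multiscale loop described by the comment above.
def forward(self, input):
    result = []
    get_intermediate_features = not self.opt.no_ganFeat_loss
    for name, D in self.named_children():
        out = D(input)
        if not get_intermediate_features:
            out = [out]
        result.append(out)              # one entry per scale -> opt.num_D entries overall
        input = self.downsample(input)  # halve resolution before the next discriminator
    return result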


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseNetwork):
@staticmethod
@@ -83,10 +76,10 @@ def __init__(self, opt):
self.opt = opt

kw = 4
padw = int(np.ceil((kw-1.0)/2))
padw = int(np.ceil((kw - 1.0) / 2))
nf = opt.ndf
input_nc = self.compute_D_input_nc(opt)

norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, False)]]
@@ -97,14 +90,14 @@ def __init__(self, opt):
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
stride=2, padding=padw)),
nn.LeakyReLU(0.2, False)
]]

sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

## We divide the layers into groups to extract intermediate layer outputs
# We divide the layers into groups to extract intermediate layer outputs
for n in range(len(sequence)):
self.add_module('model'+str(n), nn.Sequential(*sequence[n]))
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

def compute_D_input_nc(self, opt):
input_nc = opt.label_nc + opt.output_nc
if opt.contain_dontcare_label:
@@ -124,6 +117,3 @@ def forward(self, input):
return results[1:]
else:
return results[-1]



14 changes: 6 additions & 8 deletions models/networks/encoder.py
@@ -3,22 +3,21 @@
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer


class ConvEncoder(BaseNetwork):
""" Same architecture as the image discriminator """

def __init__(self, opt):
super().__init__()

kw = 3
pw = int(np.ceil((kw-1.0)/2))
pw = int(np.ceil((kw - 1.0) / 2))
ndf = opt.ngf
norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
@@ -30,15 +29,15 @@ def __init__(self, opt):
self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

self.so = s0 = 4
self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)

self.actvn = nn.LeakyReLU(0.2, False)

def forward(self, x):
if x.size(2) != 256 or x.size(3) != 256:
x = F.interpolate(x, size=(256, 256), mode='bilinear')

x = self.layer1(x)
x = self.layer2(self.actvn(x))
x = self.layer3(self.actvn(x))
@@ -53,4 +52,3 @@ def forward(self, x):
logvar = self.fc_var(x)

return mu, logvar
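
Each of the six stride-2 convolutions halves the spatial size, so a 256x256 input reaches 4x4 before the fully connected layers; that is where self.so = 4 and the ndf * 8 * s0 * s0 input width of fc_mu and fc_var come from. A hedged sketch of how the returned (mu, logvar) pair is typically turned into a latent code via the reparameterization trick follows; the helper name is assumed and is not part of this diff.

# Hedged sketch (not part of this diff): sampling a latent code z from the
# (mu, logvar) pair returned by ConvEncoder.forward, keeping the sampling
# differentiable via the standard reparameterization trick.
import torch

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)   # logvar -> standard deviation
    eps = torch.randn_like(std)     # unit Gaussian noise
    return mu + eps * std           # z ~ N(mu, std^2)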
