Commit: DualGAN

qzq2514 committed Dec 27, 2019
1 parent 66eb21d commit ac65f0b
Showing 12 changed files with 737 additions and 0 deletions.
50 changes: 50 additions & 0 deletions GANs_Advanced/DiscoGAN/tools/eval_DualGAN_ckpt.py
@@ -0,0 +1,50 @@
import numpy as np
import tensorflow as tf
from DataLoader import Pix2Pix_loader
import cv2

model_path = 'models/ckpt/DualGAN_64_1227_new.ckpt-10000'
image_dir = "D:/forTensorflow/forGAN/edges2shoes/"
batch_size = 1
image_height = 64
image_width = 64

def eval():
with tf.Session() as sess:
ckpt_path = model_path
saver = tf.train.import_meta_graph(ckpt_path + '.meta')
saver.restore(sess, ckpt_path)

input_A_place = tf.get_default_graph().get_tensor_by_name('input_A:0')
input_B_place = tf.get_default_graph().get_tensor_by_name('input_B:0')
keep_prob_place = tf.get_default_graph().get_tensor_by_name('keep_prob:0')
is_training = tf.get_default_graph().get_tensor_by_name('is_training:0')

A2B_output = tf.get_default_graph().get_tensor_by_name('A2B_output:0')
B2A_output = tf.get_default_graph().get_tensor_by_name('B2A_output:0')

dataLoader = Pix2Pix_loader(image_dir, image_height, image_width, batch_size=batch_size)

        while True:
images_A, images_B = dataLoader.random_next_test_batch()

_A2B_output = sess.run(A2B_output, feed_dict={input_A_place: images_A,
is_training:False,keep_prob_place:0.5})
_A2B_output = (_A2B_output + 1) / 2 * 255.0
_A2B_output = _A2B_output.astype(np.uint8)

_B2A_output = sess.run(B2A_output, feed_dict={input_B_place: images_B,
is_training:False,keep_prob_place:0.5})
_B2A_output = (_B2A_output + 1) / 2 * 255.0
_B2A_output = _B2A_output.astype(np.uint8)

cv2.imshow("A", images_A[0])
cv2.imshow("B", images_B[0])
cv2.imshow("A2B_output",_A2B_output[0])
cv2.imshow("B2A_output", _B2A_output[0])

cv2.waitKey(0)

if __name__ == '__main__':
eval()
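Note: the tensor names this script looks up (input_A, input_B, keep_prob, is_training, A2B_output, B2A_output) must match names exported by the training graph. The train script is not part of this commit, so the following is only a minimal sketch of how such a graph could be wired up; the placeholder shapes and the lambda_reconst value are assumptions:

import tensorflow as tf
from net.DualGAN import DualGAN

input_A = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_A")
input_B = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_B")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
is_training = tf.placeholder(tf.bool, name="is_training")

dual_gan = DualGAN(is_training, keep_prob, lambda_reconst=10.0)
Gen_loss, Dis_loss = dual_gan.build_DiscoGAN(input_A, input_B)  # creates the generator/discriminator variables

# sample_generate reuses the generators built above; tf.identity pins the output names.
A2B, _ = dual_gan.sample_generate(input_A, type="A2B")
B2A, _ = dual_gan.sample_generate(input_B, type="B2A")
A2B_output = tf.identity(A2B, name="A2B_output")
B2A_output = tf.identity(B2A, name="B2A_output")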
67 changes: 67 additions & 0 deletions GANs_Advanced/DualGAN_64/DataLoader.py
@@ -0,0 +1,67 @@
from scipy import misc
import numpy as np
import os


class Pix2Pix_loader:
def __init__(self, root, image_height, image_width, batch_size, global_step=0):
self.train_fileList = self.get_filePath(os.path.join(root, "train"))
self.val_fileList = self.get_filePath(os.path.join(root, "val"))

self.image_height = image_height
self.image_width = image_width
self.batch_size = batch_size

self.sample_num = len(self.train_fileList)
self.test_sample_num = len(self.val_fileList)
self.step_per_epoch = self.sample_num // self.batch_size
self.batch_id = global_step % self.step_per_epoch

def get_filePath(self, dir):
file_paths = [os.path.join(dir, file_name) for file_name in os.listdir(dir)]
return np.array(file_paths)

def shuffle_train(self):
shuffle_ind = np.arange(0, self.sample_num)
np.random.shuffle(shuffle_ind)
self.train_fileList = self.train_fileList[shuffle_ind]

def next_batch_core(self, batch_fileList):
images_A = []
images_B = []
        # Do not use cv2.imread / cv2.resize here: the edge images come back with broken
        # contours after loading and resizing, which severely hurts the generated results.
for file_path in batch_fileList:
image = misc.imread(file_path)
h, w = image.shape[:2]
image_A = misc.imresize(image[:, :w // 2, ...], [self.image_height, self.image_width])
image_B = misc.imresize(image[:, w // 2:, ...], [self.image_height, self.image_width])
images_A.append(image_A)
images_B.append(image_B)

images_A = np.array(images_A)
images_B = np.array(images_B)
        # DualGAN trains on unpaired samples, so shuffle the two domains independently.
        np.random.shuffle(images_A)
        np.random.shuffle(images_B)
        return images_A, images_B

def random_next_test_batch(self):
indices = np.random.choice(self.test_sample_num, self.batch_size)
batch_fileList = self.val_fileList[indices]
return self.next_batch_core(batch_fileList)

def random_next_train_batch(self):
indices = np.random.choice(self.sample_num, self.batch_size)
batch_fileList = self.train_fileList[indices]
return self.next_batch_core(batch_fileList)

def next_batch(self):
if self.batch_id >= self.step_per_epoch:
self.batch_id = 0
self.shuffle_train()
# print("batch_id:",self.batch_id)
        batch_ind_start = self.batch_id * self.batch_size
        batch_ind_end = (self.batch_id + 1) * self.batch_size
        batch_fileList = self.train_fileList[batch_ind_start:batch_ind_end]
self.batch_id += 1
return self.next_batch_core(batch_fileList)
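A minimal usage sketch for Pix2Pix_loader, assuming an edges2shoes-style root with train/ and val/ subfolders of side-by-side A|B images (the path and batch size are placeholders):

from DataLoader import Pix2Pix_loader

loader = Pix2Pix_loader("D:/forTensorflow/forGAN/edges2shoes/", 64, 64, batch_size=16)
images_A, images_B = loader.next_batch()        # sequential batch from train/
val_A, val_B = loader.random_next_test_batch()  # random batch from val/
print(images_A.shape)                           # (16, 64, 64, 3), uint8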
63 changes: 63 additions & 0 deletions GANs_Advanced/DualGAN_64/dbread.py
@@ -0,0 +1,63 @@
import numpy as np
from scipy import misc
from os import listdir
from os.path import isfile, join


class DBreader:
def __init__(self, filedir, batch_size, resize=0, labeled=True, color=True, shuffle=True):
self.color = color
self.labeled = labeled

self.batch_size = batch_size
tmp_filelist = [(filedir + '/' + f) for f in listdir(filedir) if isfile(join(filedir, f))]
tmp_filelist = np.array(tmp_filelist)

self.file_len = len(tmp_filelist)

self.filelist = []
self.labellist = []
if self.labeled:
for i in range(self.file_len):
                parts = tmp_filelist[i].split(" ")  # assumes "filepath label" entries
                self.filelist.append(parts[0])
                self.labellist.append(parts[1])
else:
self.filelist = tmp_filelist

self.batch_idx = 0
self.total_batch = int(self.file_len / batch_size)
self.idx_shuffled = np.arange(self.file_len)
if shuffle:
np.random.shuffle(self.idx_shuffled)
self.resize = resize

self.filelist = np.array(self.filelist)
self.labellist = np.array(self.labellist)

    # Return the next batch of images (and labels, when labeled=True).
def next_batch(self):
if self.batch_idx == self.total_batch:
np.random.shuffle(self.idx_shuffled)
self.batch_idx = 0

batch = []
idx_set = self.idx_shuffled[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size]
batch_filelist = self.filelist[idx_set]

for i in range(self.batch_size):
im = misc.imread(batch_filelist[i])
if self.resize != 0:
im = misc.imresize(im, self.resize)
if self.color:
if im.shape[2] > 3:
im = im[:, :, 0:3]
batch.append(im)

if self.labeled:
label = self.labellist[idx_set]
self.batch_idx += 1
return np.array(batch).astype(np.float32), np.array(label).astype(np.int32)

self.batch_idx += 1
return np.array(batch).astype(np.float32)
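DBreader is not referenced by the other files in this commit; for completeness, a hedged usage sketch (the directory path and resize shape are assumptions):

from dbread import DBreader

# Unlabeled color images, resized to 64x64 on load.
reader = DBreader("data/shoes", batch_size=16, resize=(64, 64), labeled=False)
batch = reader.next_batch()  # float32 array of shape (16, 64, 64, 3)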
150 changes: 150 additions & 0 deletions GANs_Advanced/DualGAN_64/net/DualGAN.py
@@ -0,0 +1,150 @@
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

class DualGAN:
def __init__(self,is_training,keep_prob,lambda_reconst):
self.is_training = is_training
self.epsilon = 1e-5
self.weight_decay = 0.00001
self.keep_prob = keep_prob
self.lambda_reconst = lambda_reconst

def preprocess(self,images,scale=False):
images = tf.to_float(images)
if scale:
images = tf.div(images, 127.5)
images = tf.subtract(images, 1.0)
return images

#[None,64,64,3]-->[None,64,64,3]
def generator(self,inputs,name_scope,reuse=False):
with tf.variable_scope(name_scope,reuse=reuse) as scope:
w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
with slim.arg_scope([slim.conv2d], padding="SAME", activation_fn=None, stride=2,kernel_size=[5,5],
weights_initializer=w_init,weights_regularizer=slim.l2_regularizer(self.weight_decay)):
with slim.arg_scope([slim.conv2d_transpose], padding="SAME", activation_fn=None, stride=2,kernel_size=[5,5],
weights_initializer=w_init,weights_regularizer=slim.l2_regularizer(self.weight_decay)):
                    # updates_collections=None forces the batch-norm moving averages to update in place.
with slim.arg_scope([slim.batch_norm], decay=0.9, epsilon=1e-5, scale=True,updates_collections=None,
activation_fn=None,is_training=self.is_training):
#Encode
e1 = slim.conv2d(inputs,64,activation_fn=None) #[None,32,32,64]
e2 = slim.conv2d(tf.nn.leaky_relu(e1), 64*2) #[None,16,16,128]
e2 = slim.batch_norm(e2)
e3 = slim.conv2d(tf.nn.leaky_relu(e2), 64*4) #[None,8,8,256]
e3 = slim.batch_norm(e3)
e4 = slim.conv2d(tf.nn.leaky_relu(e3), 64*8) #[None,4,4,512]
e4 = slim.batch_norm(e4)
e5 = slim.conv2d(tf.nn.leaky_relu(e4), 64*8) # [None,2,2,512]
e5 = slim.batch_norm(e5)
e6 = slim.conv2d(tf.nn.leaky_relu(e5), 64*8) # [None,1,1,512]
e6 = slim.batch_norm(e6)
# e7 = slim.conv2d(tf.nn.leaky_relu(e6), 64*8) # [None,1,1,512]
# e7 = slim.batch_norm(e7)

#Decode
d1 = slim.conv2d_transpose(tf.nn.relu(e6), 64 * 8)
d1 = tf.nn.dropout(slim.batch_norm(d1),self.keep_prob)
d1 = tf.concat([d1, e5],3) # [None,2,2,512*2]
d2 = slim.conv2d_transpose(tf.nn.relu(d1), 64 * 8)
d2 = tf.nn.dropout(slim.batch_norm(d2), self.keep_prob)
d2 = tf.concat([d2, e4], 3) # [None,4,4,512*2]
d3 = slim.conv2d_transpose(tf.nn.relu(d2), 64 * 8)
d3 = tf.nn.dropout(slim.batch_norm(d3), self.keep_prob)
d3 = tf.concat([d3, e3], 3) # [None,8,8,512*2]
d4 = slim.conv2d_transpose(tf.nn.relu(d3), 64 * 4)
d4 = tf.nn.dropout(slim.batch_norm(d4), self.keep_prob)
d4 = tf.concat([d4, e2], 3) # [None,16,16,256*2]
d5 = slim.conv2d_transpose(tf.nn.relu(d4), 64 * 2)
d5 = tf.nn.dropout(slim.batch_norm(d5), self.keep_prob)
d5 = tf.concat([d5, e1], 3) # [None,32,32,128*2]
# d6 = slim.conv2d_transpose(tf.nn.relu(d5), 64 * 2)
# d6 = tf.nn.dropout(slim.batch_norm(d6), self.keep_prob)
# d6 = tf.concat([d6, e1], 3)
d6 = slim.conv2d_transpose(tf.nn.relu(d5), 3) # [None,64,64,3]
generate_out = tf.nn.tanh(d6)
return generate_out

    # [None,64,64,3]-->[None,8,8,1] (patch-level real/fake logits)
def discriminator(self,inputs,name_scope,reuse=False):
with tf.variable_scope(name_scope,reuse=reuse) as scope:
w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
with slim.arg_scope([slim.conv2d], padding="SAME", activation_fn=None, stride=2,kernel_size=[5,5],
weights_initializer=w_init,weights_regularizer=slim.l2_regularizer(self.weight_decay)):
                # updates_collections=None forces the batch-norm moving averages to update in place.
with slim.arg_scope([slim.batch_norm], decay=0.9, epsilon=1e-5, scale=True,updates_collections=None,
activation_fn=tf.nn.leaky_relu, is_training=self.is_training):
feature1 = slim.conv2d(inputs, 64,activation_fn = tf.nn.leaky_relu) # [None,32,32,64]
feature2 = slim.conv2d(feature1, 128) # [None,16,16,128]
feature2 = slim.batch_norm(feature2)
feature3 = slim.conv2d(feature2, 256) # [None,8,8,256]
feature3 = slim.batch_norm(feature3)
feature4 = slim.conv2d(feature3, 512,stride=1) # [None,8,8,512]
feature4 = slim.batch_norm(feature4)
out_logits = slim.conv2d(feature4, 1,stride=1) # [None,8,8,1]
return out_logits

def get_vars(self):
all_vars = tf.trainable_variables()
g_vars = [var for var in all_vars if var.name.startswith("generator_")]
d_vars = [var for var in all_vars if var.name.startswith("discriminator_")]
return g_vars, d_vars

def build_DiscoGAN(self,input_A,input_B):
        # Scale inputs to [-1, 1]
input_A_pre = self.preprocess(input_A, scale=True)
input_B_pre = self.preprocess(input_B, scale=True)

#Domain A --> Domain B
AB = self.generator(input_A_pre,"generator_AB")
ABA = self.generator(AB,"generator_BA")

AB_logits = self.discriminator(AB, "discriminator_B")
B_logits = self.discriminator(input_B_pre, "discriminator_B", reuse=True)
reconst_A_loss = tf.reduce_mean(tf.square(ABA - input_A_pre))
fake_Gen_AB_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=AB_logits, labels=tf.ones_like(AB_logits)))
Gen_AB_loss = fake_Gen_AB_loss + self.lambda_reconst*reconst_A_loss
fake_Dis_AB_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=AB_logits, labels=tf.zeros_like(AB_logits)))
real_Dis_B_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=B_logits, labels=tf.ones_like(B_logits)))
Dis_B_loss = fake_Dis_AB_loss + real_Dis_B_loss

# Domain B --> Domain A
BA = self.generator(input_B_pre, "generator_BA",reuse=True)
BAB = self.generator(BA, "generator_AB",reuse=True)

BA_logits = self.discriminator(BA, "discriminator_A")
A_logits = self.discriminator(input_A_pre, "discriminator_A", reuse=True)
reconst_B_loss = tf.reduce_mean(tf.square(BAB - input_B_pre))
fake_Gen_BA_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=BA_logits, labels=tf.ones_like(BA_logits)))
Gen_BA_loss = fake_Gen_BA_loss + self.lambda_reconst*reconst_B_loss
fake_Dis_BA_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=BA_logits, labels=tf.zeros_like(BA_logits)))
real_Dis_A_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=A_logits, labels=tf.ones_like(A_logits)))
Dis_A_loss = fake_Dis_BA_loss + real_Dis_A_loss

Gen_loss = Gen_AB_loss + Gen_BA_loss
Dis_loss = Dis_B_loss + Dis_A_loss

return Gen_loss,Dis_loss

def sample_generate(self,input,type="A2B"):
if type=="A2B":
name_scope_first = "generator_AB"
name_scope_second = "generator_BA"
else:
name_scope_first = "generator_BA"
name_scope_second = "generator_AB"
input_pre = self.preprocess(input, scale=True)
generated_out = self.generator(input_pre,name_scope=name_scope_first,reuse=True)
reconst_image = self.generator(generated_out,name_scope=name_scope_second,reuse=True)

return generated_out,reconst_image
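
The train script is not included in this commit; a hedged sketch of how Gen_loss/Dis_loss and get_vars would typically drive alternating updates (the optimizer choice, learning rate, and beta1 are assumptions):

Gen_loss, Dis_loss = dual_gan.build_DiscoGAN(input_A, input_B)
g_vars, d_vars = dual_gan.get_vars()

# Each optimizer updates only its own sub-network's variables.
gen_train_op = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(Gen_loss, var_list=g_vars)
dis_train_op = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(Dis_loss, var_list=d_vars)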



50 changes: 50 additions & 0 deletions GANs_Advanced/DualGAN_64/tools/ckpt2pb.py
@@ -0,0 +1,50 @@

import tensorflow as tf
import os.path
from tensorflow.python.framework import graph_util

MODEL_DIR = "models/pb"
MODEL_NAME = "frozen_model.pb"

if not tf.gfile.Exists(MODEL_DIR):  # create the output directory if it does not exist
    tf.gfile.MakeDirs(MODEL_DIR)

def freeze_graph(model_folder):
    checkpoint = tf.train.get_checkpoint_state(model_folder)  # locate a usable checkpoint under the folder
    input_checkpoint = checkpoint.model_checkpoint_path  # path of the ckpt file
    output_graph = os.path.join(MODEL_DIR, MODEL_NAME)  # where the frozen PB model is written

    output_node_names = "A2B_output,B2A_output"  # names of the graph's output nodes
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)  # import the graph; clear_devices strips device placements

    graph = tf.get_default_graph()  # the imported default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph

with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the trained weights into the graph
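        # Workaround for freezing graphs that contain batch-norm update ops:
        # rewrite RefSwitch / AssignSub nodes so convert_variables_to_constants can run.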
for node in input_graph_def.node:
if node.op == 'RefSwitch':
node.op = 'Switch'
for index in range(len(node.input)):
if 'moving_' in node.input[index]:
node.input[index] = node.input[index] + '/read'
elif node.op == 'AssignSub':
node.op = 'Sub'
if 'use_locking' in node.attr: del node.attr['use_locking']

        output_graph_def = graph_util.convert_variables_to_constants(  # freeze: bake variable values into constants
            sess, input_graph_def, output_node_names.split(","))  # multiple output nodes are comma-separated

        with tf.gfile.GFile(output_graph, "wb") as f:  # write the frozen model
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))  # op count of the frozen graph

# for op in graph.get_operations():
# print(op.name, op.values())

if __name__ == '__main__':
model_folder = "models/ckpt"
freeze_graph(model_folder)
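For completeness, a minimal sketch of loading the resulting frozen_model.pb and running the exported outputs (the input tensor names are assumed from the eval script above):

import numpy as np
import tensorflow as tf

with tf.gfile.GFile("models/pb/frozen_model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
    with tf.Session(graph=graph) as sess:
        feed = {graph.get_tensor_by_name("input_A:0"): np.zeros((1, 64, 64, 3), np.float32),
                graph.get_tensor_by_name("is_training:0"): False,
                graph.get_tensor_by_name("keep_prob:0"): 0.5}
        fake_B = sess.run(graph.get_tensor_by_name("A2B_output:0"), feed_dict=feed)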