-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
737 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from DataLoader import Pix2Pix_loader | ||
import cv2 | ||
|
||
model_path = 'models/ckpt/DualGAN_64_1227_new.ckpt-10000' | ||
image_dir = "D:/forTensorflow/forGAN/edges2shoes/" | ||
batch_size = 1 | ||
image_height = 64 | ||
image_width = 64 | ||
|
||
def eval(): | ||
with tf.Session() as sess: | ||
ckpt_path = model_path | ||
saver = tf.train.import_meta_graph(ckpt_path + '.meta') | ||
saver.restore(sess, ckpt_path) | ||
|
||
input_A_place = tf.get_default_graph().get_tensor_by_name('input_A:0') | ||
input_B_place = tf.get_default_graph().get_tensor_by_name('input_B:0') | ||
keep_prob_place = tf.get_default_graph().get_tensor_by_name('keep_prob:0') | ||
is_training = tf.get_default_graph().get_tensor_by_name('is_training:0') | ||
|
||
A2B_output = tf.get_default_graph().get_tensor_by_name('A2B_output:0') | ||
B2A_output = tf.get_default_graph().get_tensor_by_name('B2A_output:0') | ||
|
||
dataLoader = Pix2Pix_loader(image_dir, image_height, image_width, batch_size=batch_size) | ||
|
||
index = 1 | ||
while True: | ||
images_A, images_B = dataLoader.random_next_test_batch() | ||
|
||
_A2B_output = sess.run(A2B_output, feed_dict={input_A_place: images_A, | ||
is_training:False,keep_prob_place:0.5}) | ||
_A2B_output = (_A2B_output + 1) / 2 * 255.0 | ||
_A2B_output = _A2B_output.astype(np.uint8) | ||
|
||
_B2A_output = sess.run(B2A_output, feed_dict={input_B_place: images_B, | ||
is_training:False,keep_prob_place:0.5}) | ||
_B2A_output = (_B2A_output + 1) / 2 * 255.0 | ||
_B2A_output = _B2A_output.astype(np.uint8) | ||
|
||
cv2.imshow("A", images_A[0]) | ||
cv2.imshow("B", images_B[0]) | ||
cv2.imshow("A2B_output",_A2B_output[0]) | ||
cv2.imshow("B2A_output", _B2A_output[0]) | ||
|
||
cv2.waitKey(0) | ||
|
||
if __name__ == '__main__': | ||
eval() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from scipy import misc | ||
import numpy as np | ||
import os | ||
|
||
|
||
class Pix2Pix_loader: | ||
def __init__(self, root, image_height, image_width, batch_size, global_step=0): | ||
self.train_fileList = self.get_filePath(os.path.join(root, "train")) | ||
self.val_fileList = self.get_filePath(os.path.join(root, "val")) | ||
|
||
self.image_height = image_height | ||
self.image_width = image_width | ||
self.batch_size = batch_size | ||
|
||
self.sample_num = len(self.train_fileList) | ||
self.test_sample_num = len(self.val_fileList) | ||
self.step_per_epoch = self.sample_num // self.batch_size | ||
self.batch_id = global_step % self.step_per_epoch | ||
|
||
def get_filePath(self, dir): | ||
file_paths = [os.path.join(dir, file_name) for file_name in os.listdir(dir)] | ||
return np.array(file_paths) | ||
|
||
def shuffle_train(self): | ||
shuffle_ind = np.arange(0, self.sample_num) | ||
np.random.shuffle(shuffle_ind) | ||
self.train_fileList = self.train_fileList[shuffle_ind] | ||
|
||
def next_batch_core(self, batch_fileList): | ||
images_A = [] | ||
images_B = [] | ||
# 不要使用cv2.imread和cv2.resize不然读取和resize后的边缘图会产生轮廓断点的现象 | ||
# 严重影响最终图像生成效果 | ||
for file_path in batch_fileList: | ||
image = misc.imread(file_path) | ||
h, w = image.shape[:2] | ||
image_A = misc.imresize(image[:, :w // 2, ...], [self.image_height, self.image_width]) | ||
image_B = misc.imresize(image[:, w // 2:, ...], [self.image_height, self.image_width]) | ||
images_A.append(image_A) | ||
images_B.append(image_B) | ||
|
||
images_A = np.array(images_A) | ||
images_B = np.array(images_B) | ||
np.random.shuffle(images_A) | ||
np.random.shuffle(images_A) | ||
return np.array(images_A), np.array(images_B) | ||
|
||
def random_next_test_batch(self): | ||
indices = np.random.choice(self.test_sample_num, self.batch_size) | ||
batch_fileList = self.val_fileList[indices] | ||
return self.next_batch_core(batch_fileList) | ||
|
||
def random_next_train_batch(self): | ||
indices = np.random.choice(self.sample_num, self.batch_size) | ||
batch_fileList = self.train_fileList[indices] | ||
return self.next_batch_core(batch_fileList) | ||
|
||
def next_batch(self): | ||
if self.batch_id >= self.step_per_epoch: | ||
self.batch_id = 0 | ||
self.shuffle_train() | ||
# print("batch_id:",self.batch_id) | ||
bacth_ind_start = self.batch_id * self.batch_size | ||
bacth_ind_end = (self.batch_id + 1) * self.batch_size | ||
batch_fileList = self.train_fileList[bacth_ind_start:bacth_ind_end] | ||
self.batch_id += 1 | ||
return self.next_batch_core(batch_fileList) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import numpy as np | ||
from scipy import misc | ||
from os import listdir | ||
from os.path import isfile, join | ||
|
||
|
||
class DBreader: | ||
def __init__(self, filedir, batch_size, resize=0, labeled=True, color=True, shuffle=True): | ||
self.color = color | ||
self.labeled = labeled | ||
|
||
self.batch_size = batch_size | ||
tmp_filelist = [(filedir + '/' + f) for f in listdir(filedir) if isfile(join(filedir, f))] | ||
tmp_filelist = np.array(tmp_filelist) | ||
|
||
self.file_len = len(tmp_filelist) | ||
|
||
self.filelist = [] | ||
self.labellist = [] | ||
if self.labeled: | ||
for i in range(self.file_len): | ||
splited = (tmp_filelist[i]).split(" ") | ||
self.filelist.append(splited[0]) | ||
self.labellist.append(splited[1]) | ||
else: | ||
self.filelist = tmp_filelist | ||
|
||
self.batch_idx = 0 | ||
self.total_batch = int(self.file_len / batch_size) | ||
self.idx_shuffled = np.arange(self.file_len) | ||
if shuffle: | ||
np.random.shuffle(self.idx_shuffled) | ||
self.resize = resize | ||
|
||
self.filelist = np.array(self.filelist) | ||
self.labellist = np.array(self.labellist) | ||
|
||
# Method for get the next batch | ||
def next_batch(self): | ||
if self.batch_idx == self.total_batch: | ||
np.random.shuffle(self.idx_shuffled) | ||
self.batch_idx = 0 | ||
|
||
batch = [] | ||
idx_set = self.idx_shuffled[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size] | ||
batch_filelist = self.filelist[idx_set] | ||
|
||
for i in range(self.batch_size): | ||
im = misc.imread(batch_filelist[i]) | ||
if self.resize != 0: | ||
im = misc.imresize(im, self.resize) | ||
if self.color: | ||
if im.shape[2] > 3: | ||
im = im[:, :, 0:3] | ||
batch.append(im) | ||
|
||
if self.labeled: | ||
label = self.labellist[idx_set] | ||
self.batch_idx += 1 | ||
return np.array(batch).astype(np.float32), np.array(label).astype(np.int32) | ||
|
||
self.batch_idx += 1 | ||
return np.array(batch).astype(np.float32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import tensorflow as tf | ||
import tensorflow.contrib.slim as slim | ||
import numpy as np | ||
|
||
class DualGAN: | ||
def __init__(self,is_training,keep_prob,lambda_reconst): | ||
self.is_training = is_training | ||
self.epsilon = 1e-5 | ||
self.weight_decay = 0.00001 | ||
self.keep_prob = keep_prob | ||
self.lambda_reconst = lambda_reconst | ||
|
||
def preprocess(self,images,scale=False): | ||
images = tf.to_float(images) | ||
if scale: | ||
images = tf.div(images, 127.5) | ||
images = tf.subtract(images, 1.0) | ||
return images | ||
|
||
#[None,64,64,3]-->[None,64,64,3] | ||
def generator(self,inputs,name_scope,reuse=False): | ||
with tf.variable_scope(name_scope,reuse=reuse) as scope: | ||
w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) | ||
with slim.arg_scope([slim.conv2d], padding="SAME", activation_fn=None, stride=2,kernel_size=[5,5], | ||
weights_initializer=w_init,weights_regularizer=slim.l2_regularizer(self.weight_decay)): | ||
with slim.arg_scope([slim.conv2d_transpose], padding="SAME", activation_fn=None, stride=2,kernel_size=[5,5], | ||
weights_initializer=w_init,weights_regularizer=slim.l2_regularizer(self.weight_decay)): | ||
# 使用updates_collections=None强制更新参数 | ||
with slim.arg_scope([slim.batch_norm], decay=0.9, epsilon=1e-5, scale=True,updates_collections=None, | ||
activation_fn=None,is_training=self.is_training): | ||
#Encode | ||
e1 = slim.conv2d(inputs,64,activation_fn=None) #[None,32,32,64] | ||
e2 = slim.conv2d(tf.nn.leaky_relu(e1), 64*2) #[None,16,16,128] | ||
e2 = slim.batch_norm(e2) | ||
e3 = slim.conv2d(tf.nn.leaky_relu(e2), 64*4) #[None,8,8,256] | ||
e3 = slim.batch_norm(e3) | ||
e4 = slim.conv2d(tf.nn.leaky_relu(e3), 64*8) #[None,4,4,512] | ||
e4 = slim.batch_norm(e4) | ||
e5 = slim.conv2d(tf.nn.leaky_relu(e4), 64*8) # [None,2,2,512] | ||
e5 = slim.batch_norm(e5) | ||
e6 = slim.conv2d(tf.nn.leaky_relu(e5), 64*8) # [None,1,1,512] | ||
e6 = slim.batch_norm(e6) | ||
# e7 = slim.conv2d(tf.nn.leaky_relu(e6), 64*8) # [None,1,1,512] | ||
# e7 = slim.batch_norm(e7) | ||
|
||
#Decode | ||
d1 = slim.conv2d_transpose(tf.nn.relu(e6), 64 * 8) | ||
d1 = tf.nn.dropout(slim.batch_norm(d1),self.keep_prob) | ||
d1 = tf.concat([d1, e5],3) # [None,2,2,512*2] | ||
d2 = slim.conv2d_transpose(tf.nn.relu(d1), 64 * 8) | ||
d2 = tf.nn.dropout(slim.batch_norm(d2), self.keep_prob) | ||
d2 = tf.concat([d2, e4], 3) # [None,4,4,512*2] | ||
d3 = slim.conv2d_transpose(tf.nn.relu(d2), 64 * 8) | ||
d3 = tf.nn.dropout(slim.batch_norm(d3), self.keep_prob) | ||
d3 = tf.concat([d3, e3], 3) # [None,8,8,512*2] | ||
d4 = slim.conv2d_transpose(tf.nn.relu(d3), 64 * 4) | ||
d4 = tf.nn.dropout(slim.batch_norm(d4), self.keep_prob) | ||
d4 = tf.concat([d4, e2], 3) # [None,16,16,256*2] | ||
d5 = slim.conv2d_transpose(tf.nn.relu(d4), 64 * 2) | ||
d5 = tf.nn.dropout(slim.batch_norm(d5), self.keep_prob) | ||
d5 = tf.concat([d5, e1], 3) # [None,32,32,128*2] | ||
# d6 = slim.conv2d_transpose(tf.nn.relu(d5), 64 * 2) | ||
# d6 = tf.nn.dropout(slim.batch_norm(d6), self.keep_prob) | ||
# d6 = tf.concat([d6, e1], 3) | ||
d6 = slim.conv2d_transpose(tf.nn.relu(d5), 3) # [None,64,64,3] | ||
generate_out = tf.nn.tanh(d6) | ||
return generate_out | ||
|
||
# [None,128,128,3]-->[None,1,1,1] | ||
def discriminator(self,inputs,name_scope,reuse=False): | ||
with tf.variable_scope(name_scope,reuse=reuse) as scope: | ||
w_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) | ||
with slim.arg_scope([slim.conv2d], padding="SAME", activation_fn=None, stride=2,kernel_size=[5,5], | ||
weights_initializer=w_init,weights_regularizer=slim.l2_regularizer(self.weight_decay)): | ||
# 使用updates_collections=None强制更新参数 | ||
with slim.arg_scope([slim.batch_norm], decay=0.9, epsilon=1e-5, scale=True,updates_collections=None, | ||
activation_fn=tf.nn.leaky_relu, is_training=self.is_training): | ||
feature1 = slim.conv2d(inputs, 64,activation_fn = tf.nn.leaky_relu) # [None,32,32,64] | ||
feature2 = slim.conv2d(feature1, 128) # [None,16,16,128] | ||
feature2 = slim.batch_norm(feature2) | ||
feature3 = slim.conv2d(feature2, 256) # [None,8,8,256] | ||
feature3 = slim.batch_norm(feature3) | ||
feature4 = slim.conv2d(feature3, 512,stride=1) # [None,8,8,512] | ||
feature4 = slim.batch_norm(feature4) | ||
out_logits = slim.conv2d(feature4, 1,stride=1) # [None,8,8,1] | ||
return out_logits | ||
|
||
def get_vars(self): | ||
all_vars = tf.trainable_variables() | ||
g_vars = [var for var in all_vars if var.name.startswith("generator_")] | ||
d_vars = [var for var in all_vars if var.name.startswith("discriminator_")] | ||
return g_vars, d_vars | ||
|
||
def build_DiscoGAN(self,input_A,input_B): | ||
#归一化 | ||
input_A_pre = self.preprocess(input_A, scale=True) | ||
input_B_pre = self.preprocess(input_B, scale=True) | ||
|
||
#Domain A --> Domain B | ||
AB = self.generator(input_A_pre,"generator_AB") | ||
ABA = self.generator(AB,"generator_BA") | ||
|
||
AB_logits = self.discriminator(AB, "discriminator_B") | ||
B_logits = self.discriminator(input_B_pre, "discriminator_B", reuse=True) | ||
reconst_A_loss = tf.reduce_mean(tf.square(ABA - input_A_pre)) | ||
fake_Gen_AB_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( | ||
logits=AB_logits, labels=tf.ones_like(AB_logits))) | ||
Gen_AB_loss = fake_Gen_AB_loss + self.lambda_reconst*reconst_A_loss | ||
fake_Dis_AB_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( | ||
logits=AB_logits, labels=tf.zeros_like(AB_logits))) | ||
real_Dis_B_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( | ||
logits=B_logits, labels=tf.ones_like(B_logits))) | ||
Dis_B_loss = fake_Dis_AB_loss + real_Dis_B_loss | ||
|
||
# Domain B --> Domain A | ||
BA = self.generator(input_B_pre, "generator_BA",reuse=True) | ||
BAB = self.generator(BA, "generator_AB",reuse=True) | ||
|
||
BA_logits = self.discriminator(BA, "discriminator_A") | ||
A_logits = self.discriminator(input_A_pre, "discriminator_A", reuse=True) | ||
reconst_B_loss = tf.reduce_mean(tf.square(BAB - input_B_pre)) | ||
fake_Gen_BA_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( | ||
logits=BA_logits, labels=tf.ones_like(BA_logits))) | ||
Gen_BA_loss = fake_Gen_BA_loss + self.lambda_reconst*reconst_B_loss | ||
fake_Dis_BA_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( | ||
logits=BA_logits, labels=tf.zeros_like(BA_logits))) | ||
real_Dis_A_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( | ||
logits=A_logits, labels=tf.ones_like(A_logits))) | ||
Dis_A_loss = fake_Dis_BA_loss + real_Dis_A_loss | ||
|
||
Gen_loss = Gen_AB_loss + Gen_BA_loss | ||
Dis_loss = Dis_B_loss + Dis_A_loss | ||
|
||
return Gen_loss,Dis_loss | ||
|
||
def sample_generate(self,input,type="A2B"): | ||
if type=="A2B": | ||
name_scope_first = "generator_AB" | ||
name_scope_second = "generator_BA" | ||
else: | ||
name_scope_first = "generator_BA" | ||
name_scope_second = "generator_AB" | ||
input_pre = self.preprocess(input, scale=True) | ||
generated_out = self.generator(input_pre,name_scope=name_scope_first,reuse=True) | ||
reconst_image = self.generator(generated_out,name_scope=name_scope_second,reuse=True) | ||
|
||
return generated_out,reconst_image | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
|
||
import tensorflow as tf | ||
import os.path | ||
import argparse | ||
from tensorflow.python.framework import graph_util | ||
|
||
MODEL_DIR = "models/pb" | ||
MODEL_NAME = "frozen_model.pb" | ||
|
||
if not tf.gfile.Exists(MODEL_DIR): # 创建目录 | ||
tf.gfile.MakeDirs(MODEL_DIR) | ||
|
||
def freeze_graph(model_folder): | ||
checkpoint = tf.train.get_checkpoint_state(model_folder) # 检查目录下ckpt文件状态是否可用 | ||
input_checkpoint = checkpoint.model_checkpoint_path # 得ckpt文件路径 | ||
output_graph = os.path.join(MODEL_DIR, MODEL_NAME) # PB模型保存路径 | ||
|
||
output_node_names = "A2B_output,B2A_output" # 原模型输出操作节点的名字 | ||
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', | ||
clear_devices=True) # 得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import. | ||
|
||
graph = tf.get_default_graph() # 获得默认的图 | ||
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图 | ||
|
||
with tf.Session() as sess: | ||
saver.restore(sess, input_checkpoint) # 恢复图并得到数据 | ||
for node in input_graph_def.node: | ||
if node.op == 'RefSwitch': | ||
node.op = 'Switch' | ||
for index in range(len(node.input)): | ||
if 'moving_' in node.input[index]: | ||
node.input[index] = node.input[index] + '/read' | ||
elif node.op == 'AssignSub': | ||
node.op = 'Sub' | ||
if 'use_locking' in node.attr: del node.attr['use_locking'] | ||
|
||
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 | ||
sess,input_graph_def,output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开 | ||
|
||
with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型 | ||
f.write(output_graph_def.SerializeToString()) # 序列化输出 | ||
print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点 | ||
|
||
# for op in graph.get_operations(): | ||
# print(op.name, op.values()) | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
model_folder = "models/ckpt" | ||
freeze_graph(model_folder) |
Oops, something went wrong.