Skip to content

Commit

Permalink
ready to infer on variable-length input
Browse files Browse the repository at this point in the history
  • Loading branch information
acetylSv committed Mar 8, 2018
1 parent af3fc42 commit 1cc279f
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 46 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ cycle_gan_vc_log/
get_train_infer.py
test.py
get.sh
mag_part
dataset/
model_backup/
test_result/
99 changes: 57 additions & 42 deletions cycle_gan_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,40 @@ def __init__(self, mode="train"):
# Set GAN Loss Criterion (Defined in module.py)
self.criterion = mae_criterion

if mode=='test':
batch_size = None
if mode=='train':
# normalization term
self.A_mc_mean = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.mcep_dim))
self.B_mc_mean = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.mcep_dim))
self.A_mc_std = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.mcep_dim))
self.B_mc_std = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.mcep_dim))
self.A_logf0s_mean = tf.placeholder(tf.float32, shape=(hp.batch_size,1))
self.A_logf0s_std = tf.placeholder(tf.float32, shape=(hp.batch_size,1))
self.B_logf0s_mean = tf.placeholder(tf.float32, shape=(hp.batch_size,1))
self.B_logf0s_std = tf.placeholder(tf.float32, shape=(hp.batch_size,1))
# input
self.A_x = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.fix_seq_length, hp.mcep_dim))
self.A_f0 = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.fix_seq_length))
self.A_ap = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.fix_seq_length, 1+hp.n_fft//2))
self.B_x = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.fix_seq_length, hp.mcep_dim))
self.B_f0 = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.fix_seq_length))
self.B_ap = tf.placeholder(tf.float32, shape=(hp.batch_size, hp.fix_seq_length, 1+hp.n_fft//2))
else:
batch_size = hp.batch_size

# normalization term
self.A_mc_mean = tf.placeholder(tf.float32, shape=(batch_size, hp.mcep_dim))
self.B_mc_mean = tf.placeholder(tf.float32, shape=(batch_size, hp.mcep_dim))
self.A_mc_std = tf.placeholder(tf.float32, shape=(batch_size, hp.mcep_dim))
self.B_mc_std = tf.placeholder(tf.float32, shape=(batch_size, hp.mcep_dim))
self.A_logf0s_mean = tf.placeholder(tf.float32, shape=(batch_size,1))
self.A_logf0s_std = tf.placeholder(tf.float32, shape=(batch_size,1))
self.B_logf0s_mean = tf.placeholder(tf.float32, shape=(batch_size,1))
self.B_logf0s_std = tf.placeholder(tf.float32, shape=(batch_size,1))
# input
self.A_x = tf.placeholder(tf.float32, shape=(batch_size, hp.fix_seq_length, hp.mcep_dim))
self.A_f0 = tf.placeholder(tf.float32, shape=(batch_size, hp.fix_seq_length))
self.A_ap = tf.placeholder(tf.float32, shape=(batch_size, hp.fix_seq_length, 1+hp.n_fft//2))
self.B_x = tf.placeholder(tf.float32, shape=(batch_size, hp.fix_seq_length, hp.mcep_dim))
self.B_f0 = tf.placeholder(tf.float32, shape=(batch_size, hp.fix_seq_length))
self.B_ap = tf.placeholder(tf.float32, shape=(batch_size, hp.fix_seq_length, 1+hp.n_fft//2))
# normalization term
self.A_mc_mean = tf.placeholder(tf.float32, shape=(None, hp.mcep_dim))
self.B_mc_mean = tf.placeholder(tf.float32, shape=(None, hp.mcep_dim))
self.A_mc_std = tf.placeholder(tf.float32, shape=(None, hp.mcep_dim))
self.B_mc_std = tf.placeholder(tf.float32, shape=(None, hp.mcep_dim))
self.A_logf0s_mean = tf.placeholder(tf.float32, shape=(None))
self.A_logf0s_std = tf.placeholder(tf.float32, shape=(None))
self.B_logf0s_mean = tf.placeholder(tf.float32, shape=(None))
self.B_logf0s_std = tf.placeholder(tf.float32, shape=(None))
# input
self.A_x = tf.placeholder(tf.float32, shape=(1, None, hp.mcep_dim))
self.A_f0 = tf.placeholder(tf.float32, shape=(1, None))
self.A_ap = tf.placeholder(tf.float32, shape=(1, None, 1+hp.n_fft//2))
self.B_x = tf.placeholder(tf.float32, shape=(1, None, hp.mcep_dim))
self.B_f0 = tf.placeholder(tf.float32, shape=(1, None))
self.B_ap = tf.placeholder(tf.float32, shape=(1, None, 1+hp.n_fft//2))

# Domain-Transfering
with tf.variable_scope('gen_A_to_B'):
Expand All @@ -58,28 +71,18 @@ def __init__(self, mode="train"):
with tf.variable_scope('gen_B_to_A', reuse=True):
self.A_identity_y_hat = build_generator(self.A_x)

# Discriminator
with tf.variable_scope('dis_A') as scope:
self.v_A_real_logits, self.v_A_real = build_discriminator(self.A_x)
scope.reuse_variables()
self.v_A_fake_logits, self.v_A_fake = build_discriminator(self.A_y_hat)

with tf.variable_scope('dis_B') as scope:
self.v_B_real_logits, self.v_B_real = build_discriminator(self.B_x)
scope.reuse_variables()
self.v_B_fake_logits, self.v_B_fake = build_discriminator(self.B_y_hat)

self.gen_vars = [v for v in tf.trainable_variables() if v.name.startswith('gen_')]
self.dis_vars = [v for v in tf.trainable_variables() if v.name.startswith('dis_')]
'''
for v in self.dis_vars : print(v)
print('----------------------')
for v in self.gen_vars : print(v)
'''
'''
vs = [v for v in tf.trainable_variables()]
for v in vs : print(v)
'''
if mode is 'train':
# Discriminator
with tf.variable_scope('dis_A') as scope:
self.v_A_real_logits, self.v_A_real = build_discriminator(self.A_x)
scope.reuse_variables()
self.v_A_fake_logits, self.v_A_fake = build_discriminator(self.A_y_hat)

with tf.variable_scope('dis_B') as scope:
self.v_B_real_logits, self.v_B_real = build_discriminator(self.B_x)
scope.reuse_variables()
self.v_B_fake_logits, self.v_B_fake = build_discriminator(self.B_y_hat)


# monitor
self.audio_A = tf.py_func(MCEPs2wav, [self.A_x[0], self.A_f0[0], self.A_ap[0], \
Expand All @@ -104,6 +107,18 @@ def __init__(self, mode="train"):
[self.B_identity_y_hat[0], self.B_f0[0], self.B_ap[0], \
self.B_mc_mean, self.B_mc_std, self.B_logf0s_mean, self.B_logf0s_std], tf.float32)
if mode in ("train"):
# var collection
self.gen_vars = [v for v in tf.trainable_variables() if v.name.startswith('gen_')]
self.dis_vars = [v for v in tf.trainable_variables() if v.name.startswith('dis_')]
'''
for v in self.dis_vars : print(v)
print('----------------------')
for v in self.gen_vars : print(v)
'''
'''
vs = [v for v in tf.trainable_variables()]
for v in vs : print(v)
'''
# Loss
## Generator Loss
self.loss_gen_A = self.criterion(self.v_B_fake_logits, tf.ones_like(self.v_B_fake_logits))
Expand Down
11 changes: 9 additions & 2 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def get_partition(self, idss):
return np.array(all_mcs), np.array(all_f0s), np.array(all_aps), \
np.array(all_mc_mean), np.array(all_mc_std), \
np.array(all_logf0_mean), np.array(all_logf0_std)


def get_uttrs(self, A_ids, B_ids, limit=5):
    """Return the first ``limit`` test utterance ids for two speakers.

    Args:
        A_ids: key of the first speaker under ``self.hf['test']``.
        B_ids: key of the second speaker under ``self.hf['test']``.
        limit: maximum number of utterance ids returned per speaker.
            Defaults to 5, the value that used to be hard-coded.

    Returns:
        A tuple ``(A_uttrs, B_uttrs)`` of lists holding the utterance keys
        in the order ``keys()`` yields them. A speaker with fewer than
        ``limit`` utterances simply yields a shorter list.
    """
    test_group = self.hf['test']
    # list(...) materialises the key view so it can be sliced.
    A_uttrs = list(test_group[A_ids].keys())[:limit]
    B_uttrs = list(test_group[B_ids].keys())[:limit]
    return A_uttrs, B_uttrs

def get_test_partition(self, ids, uttr_id):
mcs = self.hf['test'][ids][uttr_id]['normed_mc'][:]
f0s = self.hf['test'][ids][uttr_id]['normed_logf0'][:]
Expand All @@ -60,7 +65,9 @@ def get_test_partition(self, ids, uttr_id):
mc_std = self.hf['test'][ids][uttr_id]['mc_std'][:]
logf0_mean = self.hf['test'][ids][uttr_id]['logf0_mean'][:]
logf0_std = self.hf['test'][ids][uttr_id]['logf0_std'][:]


print(mcs.shape, f0s.shape, aps.shape, mc_mean.shape, mc_std.shape, logf0_mean.shape, logf0_std.shape)

mcs = np.expand_dims(np.reshape(mcs, [-1, hp.mcep_dim]), 0)
f0s = np.expand_dims(np.reshape(f0s, [-1]), 0)
aps = np.expand_dims(np.reshape(aps, [-1, 1+hp.n_fft//2]), 0)
Expand Down
133 changes: 133 additions & 0 deletions gan_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import sys, os
import tensorflow as tf
import numpy as np

from network import *
from hyperparams import Hyperparams as hp
from utils import *
from data_loader import *
from cycle_gan_graph import Graph

# init random_seed
#tf.set_random_seed(2401)
#np.random.seed(2401)
#random.seed(2401)

def _convert_pair(sess, g, dl, A_ids, A_uttr, B_ids, B_uttr, out_dir):
    """Convert one A/B utterance pair and write the six resulting wav files.

    Args:
        sess: session used to evaluate the graph.
        g: test-mode ``Graph`` exposing the input placeholders and the
            ``audio_*`` outputs.
        dl: data loader providing ``get_test_partition(ids, uttr_id)``.
        A_ids: speaker id of the domain-A utterance.
        A_uttr: utterance id of the domain-A utterance.
        B_ids: speaker id of the domain-B utterance.
        B_uttr: utterance id of the domain-B utterance.
        out_dir: directory the wav files are written into.
    """
    A_normed_mcs, A_normed_logf0s, A_aps, \
        A_mc_mean, A_mc_std, A_logf0_mean, A_logf0_std = \
        dl.get_test_partition(A_ids, A_uttr)
    B_normed_mcs, B_normed_logf0s, B_aps, \
        B_mc_mean, B_mc_std, B_logf0_mean, B_logf0_std = \
        dl.get_test_partition(B_ids, B_uttr)

    audio_A, audio_B, audio_A_to_B, audio_B_to_A, \
        audio_A_to_B_to_A, audio_B_to_A_to_B = sess.run(
            [g.audio_A, g.audio_B,
             g.audio_A_to_B, g.audio_B_to_A,
             g.audio_A_to_B_to_A, g.audio_B_to_A_to_B],
            feed_dict={
                g.A_x: A_normed_mcs, g.B_x: B_normed_mcs,
                g.A_f0: A_normed_logf0s, g.B_f0: B_normed_logf0s,
                g.A_ap: A_aps, g.B_ap: B_aps,
                g.A_mc_mean: A_mc_mean, g.A_mc_std: A_mc_std,
                g.B_mc_mean: B_mc_mean, g.B_mc_std: B_mc_std,
                g.A_logf0s_mean: A_logf0_mean, g.A_logf0s_std: A_logf0_std,
                g.B_logf0s_mean: B_logf0_mean, g.B_logf0s_std: B_logf0_std,
            }
        )

    # Tag every file with the speaker/utterance actually converted, so that
    # results from different speakers never share a file name.
    A_tag = A_ids + '_' + A_uttr
    B_tag = B_ids + '_' + B_uttr
    outputs = [
        ('test_A_{}.wav', A_tag, audio_A),
        ('test_B_{}.wav', B_tag, audio_B),
        ('test_A_to_B_{}.wav', A_tag, audio_A_to_B),
        ('test_B_to_A_{}.wav', B_tag, audio_B_to_A),
        ('test_A_to_B_to_A_{}.wav', A_tag, audio_A_to_B_to_A),
        ('test_B_to_A_to_B_{}.wav', B_tag, audio_B_to_A_to_B),
    ]
    # NOTE(review): `librosa` is not imported explicitly in this file; it is
    # presumably brought into scope by one of the star imports - confirm.
    for pattern, tag, audio in outputs:
        librosa.output.write_wav(
            os.path.join(out_dir, pattern.format(tag)), np.array(audio), hp.sr)


def test():
    """Restore the latest checkpoint and write converted test audio to disk.

    Builds the test-mode graph, restores the newest checkpoint found in
    ``hp.log_dir`` and converts:

    * in-domain: five fixed utterance pairs of speakers 226 (A) and 225 (B),
      written to ``test_result/in_domain``;
    * out-of-domain: the utterances returned by ``dl.get_uttrs`` for the
      speaker pairs (227, 228), (232, 229) and (237, 230), written to
      ``test_result/out_domain``.

    If no checkpoint is found, an error message is printed and nothing is
    converted. The data loader and the session are always closed, even when
    an exception is raised.
    """
    dl = Data_loader(mode='test')
    sess = None
    try:
        # Build graph
        g = Graph(mode='test')
        print("Testing Graph loaded")
        saver = tf.train.Saver(max_to_keep=5)
        sess = tf.Session()

        # Inference needs trained weights: bail out when no checkpoint exists.
        ckpt = tf.train.get_checkpoint_state(hp.log_dir)
        if not (ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path)):
            print("=====Error: model not found=====")
            return
        print("=====Reading model parameters from %s=====" % ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)

        in_domain_dir = 'test_result/in_domain'
        out_domain_dir = 'test_result/out_domain'
        # The result directories are not tracked by git, so create them here
        # instead of failing on the first write.
        for result_dir in (in_domain_dir, out_domain_dir):
            if not os.path.isdir(result_dir):
                os.makedirs(result_dir)

        # Test In-Domain
        A_ids, B_ids = '226', '225'
        A_uttrs = ['335', '336', '337', '338', '339']
        B_uttrs = ['330', '331', '332', '334', '335']
        for A_uttr, B_uttr in zip(A_uttrs, B_uttrs):
            _convert_pair(sess, g, dl, A_ids, A_uttr, B_ids, B_uttr, in_domain_dir)

        # Test Out-of-Domain
        A_idss = ['227', '232', '237']
        B_idss = ['228', '229', '230']
        for A_ids, B_ids in zip(A_idss, B_idss):
            A_uttrs, B_uttrs = dl.get_uttrs(A_ids, B_ids)
            for A_uttr, B_uttr in zip(A_uttrs, B_uttrs):
                _convert_pair(sess, g, dl, A_ids, A_uttr, B_ids, B_uttr, out_domain_dir)
    finally:
        # exit
        dl.close_hdf5()
        if sess is not None:
            sess.close()

# Script entry point: run inference only when executed directly (not on
# import), then report completion.
if __name__ == '__main__':
    test()
    print('Infer Done')
2 changes: 2 additions & 0 deletions gan_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def train():
B_mc_mean, B_mc_std, B_logf0_mean, B_logf0_std = dl.get_partition(B_idss)

while True:
if gs > 200000:
exit()
A_normed_mcs, A_normed_logf0s, A_aps, \
A_mc_mean, A_mc_std, A_logf0_mean, A_logf0_std = \
my_shuffle(A_normed_mcs, A_normed_logf0s, A_aps, \
Expand Down
2 changes: 1 addition & 1 deletion hyperparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Hyperparams:
summary_period = 300
save_period = 500
LAMBDA_CYCLE = 10
LAMBDA_IDENTITY = 5
LAMBDA_IDENTITY = 0

# Signal Processing
sr = 16000
Expand Down
2 changes: 1 addition & 1 deletion network.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def build_generator(inputs):
return h12

def build_discriminator(inputs):
# inputs_shape: [batch, w=128, c=513]
# inputs_shape: [batch, w=128, c=26]
# inputs_reshape_shape: [batch, h=26, w=128, c=1]
inputs = tf.transpose(inputs, [0, 2, 1])
inputs_reshape = tf.expand_dims(inputs, [-1])
Expand Down

0 comments on commit 1cc279f

Please sign in to comment.