Skip to content

Commit

Permalink
Wrapping up things, comments
Browse files Browse the repository at this point in the history
  • Loading branch information
paarthneekhara committed Aug 25, 2016
1 parent 2d15fd2 commit 7e61f91
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ eval_trec.py
theanotest.py
data_loader_old.py
Utils/word_embeddings_old.py
gen_backup.py
gen_backup.py
data_loader_test.py
3 changes: 0 additions & 3 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ def save_caption_vectors_flowers(data_dir):
h.create_dataset(key, data=encoded_captions[key])
h.close()




def main():
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, default='train',
Expand Down
13 changes: 9 additions & 4 deletions generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def main():
parser.add_argument('--data_dir', type=str, default="Data",
help='Data Directory')

parser.add_argument('--model_path', type=str, default='Data/Models/model_flowers_temp.ckpt',
parser.add_argument('--model_path', type=str, default='Data/Models/latest_model_flowers_temp.ckpt',
help='Trained Model Path')

parser.add_argument('--n_images', type=int, default=5,
Expand Down Expand Up @@ -89,12 +89,17 @@ def main():

for f in os.listdir( join(args.data_dir, 'val_samples')):
if os.path.isfile(f):
os.unlink(f)
os.unlink(join(args.data_dir, 'val_samples/' + f))

for cn in range(0, len(caption_vectors)):
caption_images = []
for i, im in enumerate( caption_image_dic[ cn ] ):
im_name = "caption_{}_{}.jpg".format(cn, i)
scipy.misc.imsave( join(args.data_dir, 'val_samples/{}'.format(im_name)) , im)
# im_name = "caption_{}_{}.jpg".format(cn, i)
# scipy.misc.imsave( join(args.data_dir, 'val_samples/{}'.format(im_name)) , im)
caption_images.append( im )
caption_images.append( np.zeros((64, 5, 3)) )
combined_image = np.concatenate( caption_images[0:-1], axis = 1 )
scipy.misc.imsave( join(args.data_dir, 'val_samples/combined_image_{}.jpg'.format(cn)) , combined_image)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class GAN:
'''
OPTIONs
OPTIONS
z_dim : Noise dimension 100
t_dim : Text feature dimension 256
image_size : Image Dimension 64
Expand Down
1 change: 1 addition & 0 deletions skipthoughts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'''
Skip-thought vectors
https://github.com/ryankiros/skip-thoughts
'''
import os

Expand Down
27 changes: 18 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import scipy.misc
import random
import json
import os
import shutil

def main():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -88,7 +90,7 @@ def main():
for i in range(args.epochs):
batch_no = 0
while batch_no*args.batch_size < loaded_data['data_length']:
real_images, wrong_images, caption_vectors, z_noise = get_training_batch(batch_no, args.batch_size,
real_images, wrong_images, caption_vectors, z_noise, image_files = get_training_batch(batch_no, args.batch_size,
args.image_size, args.z_dim, args.caption_vector_length, 'train', args.data_dir, args.data_set, loaded_data)

# DISCR UPDATE
Expand Down Expand Up @@ -128,10 +130,10 @@ def main():
batch_no += 1
if (batch_no % args.save_every) == 0:
print "Saving Images, Model"
save_for_vis(args.data_dir, real_images, gen)
save_path = saver.save(sess, "Data/Models/model_{}_temp.ckpt".format(args.data_set))

save_path = saver.save(sess, "Data/Models/model_{}_epoch_{}.ckpt".format(args.data_set, i))
save_for_vis(args.data_dir, real_images, gen, image_files)
save_path = saver.save(sess, "Data/Models/latest_model_{}_temp.ckpt".format(args.data_set))
if i%5 == 0:
save_path = saver.save(sess, "Data/Models/model_after_{}_epoch_{}.ckpt".format(args.data_set, i))

def load_training_data(data_dir, data_set):
if data_set == 'flowers':
Expand All @@ -158,12 +160,15 @@ def load_training_data(data_dir, data_set):
# No preloading for MS-COCO
return meta_data

def save_for_vis(data_dir, real_images, generated_images):
def save_for_vis(data_dir, real_images, generated_images, image_files):

shutil.rmtree( join(data_dir, 'samples') )
os.makedirs( join(data_dir, 'samples') )

for i in range(0, real_images.shape[0]):
real_image_255 = np.zeros( (64,64,3), dtype=np.uint8)
real_images_255 = (real_images[i,:,:,:])
scipy.misc.imsave( join(data_dir, 'samples/real_image_{}.jpg'.format(i)) , real_images_255)
scipy.misc.imsave( join(data_dir, 'samples/{}_{}.jpg'.format(i, image_files[i].split('/')[-1] )) , real_images_255)

fake_image_255 = np.zeros( (64,64,3), dtype=np.uint8)
fake_images_255 = (generated_images[i,:,:,:])
Expand All @@ -182,10 +187,12 @@ def get_training_batch(batch_no, batch_size, image_size, z_dim,
real_images = np.zeros((batch_size, 64, 64, 3))
wrong_images = np.zeros((batch_size, 64, 64, 3))

image_files = []
for idx, image_id in enumerate(image_ids):
image_file = join(data_dir, '%s2014/COCO_%s2014_%.12d.jpg'%(split, split, image_id) )
image_array = image_processing.load_image_array(image_file, image_size)
real_images[idx,:,:,:] = image_array
image_files.append(image_file)

# TODO>> As of Now, wrong images are just shuffled real images.
first_image = real_images[0,:,:,:]
Expand All @@ -198,14 +205,15 @@ def get_training_batch(batch_no, batch_size, image_size, z_dim,
z_noise = np.random.uniform(-1, 1, [batch_size, z_dim])


return real_images, wrong_images, caption_vectors, z_noise
return real_images, wrong_images, caption_vectors, z_noise, image_files

if data_set == 'flowers':
real_images = np.zeros((batch_size, 64, 64, 3))
wrong_images = np.zeros((batch_size, 64, 64, 3))
captions = np.zeros((batch_size, caption_vector_length))

cnt = 0
image_files = []
for i in range(batch_no * batch_size, batch_no * batch_size + batch_size):
idx = i % len(loaded_data['image_list'])
image_file = join(data_dir, 'flowers/jpg/'+loaded_data['image_list'][idx])
Expand All @@ -219,10 +227,11 @@ def get_training_batch(batch_no, batch_size, image_size, z_dim,

random_caption = random.randint(0,4)
captions[cnt,:] = loaded_data['captions'][ loaded_data['image_list'][idx] ][ random_caption ][0:caption_vector_length]
image_files.append( image_file )
cnt += 1

z_noise = np.random.uniform(-1, 1, [batch_size, z_dim])
return real_images, wrong_images, captions, z_noise
return real_images, wrong_images, captions, z_noise, image_files

if __name__ == '__main__':
main()

0 comments on commit 7e61f91

Please sign in to comment.