Skip to content

Commit

Permalink
Training for flowers dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
paarthneekhara committed Aug 22, 2016
1 parent 32ef023 commit d287297
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 27 deletions.
58 changes: 54 additions & 4 deletions data_loaderv2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import json
import os
from os.path import join, isfile
import re
import numpy as np
import pickle
import argparse
import skipthoughts
import h5py

def save_caption_vectors(data_dir, split, batch_size):
import h5py

def save_caption_vectors_ms_coco(data_dir, split, batch_size):
meta_data = {}
ic_file = join(data_dir, 'annotations/captions_{}2014.json'.format(split))
with open(ic_file) as f:
Expand Down Expand Up @@ -46,6 +46,51 @@ def save_caption_vectors(data_dir, split, batch_size):
print "Batches Done", batch_no, len(ic_data['annotations'])/batch_size
batch_no += 1

def save_caption_vectors_flowers(data_dir):
import time

img_dir = join(data_dir, 'flowers/jpg')
image_files = [f for f in os.listdir(img_dir) if 'jpg' in f]
print image_files[300:400]
print len(image_files)
image_captions = { img_file : [] for img_file in image_files }

caption_dir = join(data_dir, 'flowers/text_c10')
class_dirs = []
for i in range(1, 103):
class_dir_name = 'class_%.5d'%(i)
class_dirs.append( join(caption_dir, class_dir_name))

for class_dir in class_dirs:
caption_files = [f for f in os.listdir(class_dir) if 'txt' in f]
for cap_file in caption_files:
with open(join(class_dir,cap_file)) as f:
captions = f.read().split('\n')
img_file = cap_file[0:11] + ".jpg"
# 5 captions per image
image_captions[img_file] += [cap for cap in captions if len(cap) > 0][0:5]

print len(image_captions)

model = skipthoughts.load_model()
encoded_captions = {}


for i, img in enumerate(image_captions):
st = time.time()
encoded_captions[img] = skipthoughts.encode(model, image_captions[img])
print i, len(image_captions), img
print "Seconds", time.time() - st


h = h5py.File(join(data_dir, 'flower_tv.hdf5'))
for key in encoded_captions:
h.create_dataset(key, data=encoded_captions[key])
h.close()




def main():
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, default='train',
Expand All @@ -54,9 +99,14 @@ def main():
help='Data directory')
parser.add_argument('--batch_size', type=int, default=64,
help='Batch Size')
parser.add_argument('--data_set', type=str, default='flowers',
help='Data Set : Flowers, MS-COCO')
args = parser.parse_args()

save_caption_vectors(args.data_dir, args.split, args.batch_size)
if args.data_set == 'flowers':
save_caption_vectors_flowers(args.data_dir)
else:
save_caption_vectors_ms_coco(args.data_dir, args.split, args.batch_size)

if __name__ == '__main__':
main()
101 changes: 78 additions & 23 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
from Utils import image_processing
import scipy.misc
import random

def main():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -35,11 +36,13 @@ def main():
parser.add_argument('--beta1', type=float, default=0.5,
help='beta1')

parser.add_argument('--epochs', type=int, default=100,
parser.add_argument('--epochs', type=int, default=600,
help='epochs')
parser.add_argument('--resume_model', type=str, default=None,
help='Trained Model Path')

parser.add_argument('--data_set', type=str, default="flowers",
help='Dat set, :MS-COC, flowers')

args = parser.parse_args()
model_options = {
Expand Down Expand Up @@ -69,11 +72,17 @@ def main():
if args.resume_model:
saver.restore(sess, args.resume_model)

loaded_data = load_training_data(args.data_dir, args.data_set)
if args.data_set == 'flowers':
meta_data['data_length'] = len(loaded_data['image_list'])

print "Meta Data", meta_data['data_length']

for i in range(args.epochs):
batch_no = 0
while batch_no*args.batch_size < meta_data['data_length']:
real_images, wrong_images, caption_vectors, z_noise = get_training_batch(batch_no, args.batch_size,
args.image_size, args.z_dim, 'train', args.data_dir)
args.image_size, args.z_dim, 'train', args.data_dir, args.data_set, loaded_data)

# DISCR UPDATE
check_ts = [ checks['d_loss1'] , checks['d_loss2'], checks['d_loss3']]
Expand Down Expand Up @@ -114,12 +123,34 @@ def main():


batch_no += 1
if (batch_no % 100) == 0:
if (batch_no % 30) == 0:
print "Saving Images, Model"
save_for_vis(args.data_dir, real_images, gen)
save_path = saver.save(sess, "Data/Models/model_temp.ckpt")
print "LOSSES", d_loss, g_loss, batch_no, i
save_path = saver.save(sess, "Data/Models/model_epoch{}.ckpt".format(i))
save_path = saver.save(sess, "Data/Models/model_{}_temp.ckpt".format(args.data_set))
print "LOSSES", d_loss, g_loss, batch_no, i, len(loaded_data['image_list'])/ args.batch_size
save_path = saver.save(sess, "Data/Models/model_{}_epoch_{}.ckpt".format(args.data_set, i))

def load_training_data(data_dir, data_set):
	"""Load pre-computed caption vectors for training.

	For data_set == 'flowers': reads <data_dir>/flower_tv.hdf5, takes the
	first 75% of the (sorted) image names as the training split, shuffles
	it, and returns {'image_list': [...], 'captions': {name: vectors}}.

	For any other data set returns None — MS-COCO batches are read lazily
	per batch in get_training_batch.
	"""
	if data_set == 'flowers':
		flower_captions = {}
		h = h5py.File(join(data_dir, 'flower_tv.hdf5'))
		try:
			for name, ds in h.iteritems():
				# Copy the data out of the file so the handle can be closed.
				flower_captions[name] = np.array(ds)
		finally:
			# Original code leaked this handle; always close it.
			h.close()

		image_list = sorted(flower_captions)

		# First 75% of the sorted image names form the training split.
		img_75 = int(len(image_list) * 0.75)
		training_image_list = image_list[0:img_75]
		random.shuffle(training_image_list)

		return {
			'image_list': training_image_list,
			'captions': flower_captions
		}

	else:
		# No preloading for MS-COCO
		return None

def save_for_vis(data_dir, real_images, generated_images):

Expand All @@ -133,31 +164,55 @@ def save_for_vis(data_dir, real_images, generated_images):
scipy.misc.imsave(join(data_dir, 'samples/fake_image_{}.jpg'.format(i)), fake_images_255)


def get_training_batch(batch_no, batch_size, image_size, z_dim, split, data_dir):
with h5py.File( join(data_dir, 'tvs/'+split + '_tvs_' + str(batch_no))) as hf:
caption_vectors = np.array(hf.get('tv'))
caption_vectors = caption_vectors[:,0:2400]
with h5py.File( join(data_dir, 'tvs/'+split + '_tv_image_id_' + str(batch_no))) as hf:
image_ids = np.array(hf.get('tv'))
def get_training_batch(batch_no, batch_size, image_size, z_dim, split, data_dir, data_set, loaded_data = None):
	"""Assemble one training batch of (real, wrong, caption, noise) arrays.

	For 'mscoco' the skip-thought vectors and image ids for this batch are
	read from pre-computed per-batch HDF5 files under <data_dir>/tvs.
	For 'flowers' images and captions come from loaded_data as returned by
	load_training_data.

	Returns (real_images, wrong_images, caption_vectors, z_noise) where the
	image arrays are (batch_size, 64, 64, 3), the caption vectors are
	(batch_size, 2400), and z_noise is uniform in [-1, 1).
	"""
	if data_set == 'mscoco':
		with h5py.File(join(data_dir, 'tvs/' + split + '_tvs_' + str(batch_no))) as hf:
			caption_vectors = np.array(hf.get('tv'))
			# Skip-thought vectors are truncated to the first 2400 dims.
			caption_vectors = caption_vectors[:, 0:2400]
		with h5py.File(join(data_dir, 'tvs/' + split + '_tv_image_id_' + str(batch_no))) as hf:
			image_ids = np.array(hf.get('tv'))

		real_images = np.zeros((batch_size, 64, 64, 3))
		wrong_images = np.zeros((batch_size, 64, 64, 3))
		for idx, image_id in enumerate(image_ids):
			image_file = join(data_dir, '%s2014/COCO_%s2014_%.12d.jpg' % (split, split, image_id))
			real_images[idx, :, :, :] = image_processing.load_image_array(image_file, image_size)

		# FIXME (acknowledged bug in the original): the "wrong" images are
		# just this batch reversed, so every real image also appears as a
		# mismatched example with someone else's caption.
		for i in range(0, batch_size):
			wrong_images[i, :, :, :] = real_images[batch_size - i - 1, :, :, :]

		z_noise = np.random.uniform(-1, 1, [batch_size, z_dim])
		return real_images, wrong_images, caption_vectors, z_noise

	if data_set == 'flowers':
		real_images = np.zeros((batch_size, 64, 64, 3))
		wrong_images = np.zeros((batch_size, 64, 64, 3))
		captions = np.zeros((batch_size, 2400))

		image_list = loaded_data['image_list']
		n_images = len(image_list)

		cnt = 0
		for i in range(batch_no * batch_size, batch_no * batch_size + batch_size):
			# Wrap around so the last partial batch reuses early images.
			idx = i % n_images
			image_file = join(data_dir, 'flowers/jpg/' + image_list[idx])
			real_images[cnt, :, :, :] = image_processing.load_image_array(image_file, image_size)

			# A uniformly random image serves as the mismatched pair.
			wrong_image_id = random.randint(0, n_images - 1)
			wrong_image_file = join(data_dir, 'flowers/jpg/' + image_list[wrong_image_id])
			wrong_images[cnt, :, :, :] = image_processing.load_image_array(wrong_image_file, image_size)

			# Each image has 5 captions; pick one at random, keep 2400 dims.
			random_caption = random.randint(0, 4)
			captions[cnt, :] = loaded_data['captions'][image_list[idx]][random_caption][0:2400]

			cnt += 1

		z_noise = np.random.uniform(-1, 1, [batch_size, z_dim])
		return real_images, wrong_images, captions, z_noise

if __name__ == '__main__':
main()

0 comments on commit d287297

Please sign in to comment.