Skip to content

Commit

Permalink
Using feed dict for train and validate
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Dec 1, 2016
1 parent 66916a6 commit 9a68f03
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 52 deletions.
109 changes: 69 additions & 40 deletions cifar10/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,29 @@

def train():

# Placeholder inputs
images, labels = support.placeholder_inputs()

# Training data
train_images, train_labels = support.inputs(
train_images, train_labels = support.data_inputs(
FLAGS.datadir,
support.TRAINING_DATA,
FLAGS.batch_size)

# Validation data
validate_images, validate_labels = support.inputs(
validate_images, validate_labels = support.data_inputs(
FLAGS.datadir,
support.VALIDATION_DATA,
FLAGS.batch_size)

# Model and training ops
validate_flag = tf.placeholder(tf.bool, ())
predict = support.inference(train_images, validate_images, validate_flag)
loss = support.loss(predict, train_labels)
predict = support.inference(images)
loss = support.loss(predict, labels)
global_step = tf.Variable(0, trainable=False)
train, learning_rate = support.train(loss, FLAGS.batch_size, global_step)
train, learning_rate = support.train(loss, global_step, FLAGS.batch_size)

# Accuracy
accuracy = support.accuracy(
predict, train_labels, validate_labels, validate_flag)
accuracy = support.accuracy(predict, labels)

# Summaries
tf.scalar_summary("loss", loss)
Expand All @@ -51,50 +52,78 @@ def train():
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Helpers to log status
def log_status(train_step, validate=False):
print("Step %i:" % train_step, end="")
log_train_status(train_step)
if validate:
log_validate_status(train_step)
print()
# Helper to read the next batch from a pair of queue-fed tensors
def next_batch(images_source, labels_source):
    """Materialize one batch and map it onto the input placeholders.

    Evaluates `images_source` / `labels_source` once and returns a
    feed_dict keyed by the `images` and `labels` placeholders.
    """
    batch = sess.run([images_source, labels_source])
    return dict(zip([images, labels], batch))

def log_train_status(step):
summaries_, accuracy_ = sess.run(
[summaries, accuracy],
feed_dict={validate_flag: False})
train_writer.add_summary(summaries_, step)
train_writer.flush()
print(" training=%f" % accuracy_, end="")

def log_validate_status(train_step):
accuracies = []
validate_steps = support.VALIDATION_IMAGES_COUNT // FLAGS.batch_size
# Helper to run one full pass over the validation set
def validate():
    """Compute mean accuracy over the validation data.

    Returns a (summary, accuracy) pair where `summary` is a tf.Summary
    carrying the scalar accuracy and `accuracy` is its float value.
    """
    steps = support.VALIDATION_IMAGES_COUNT // FLAGS.batch_size
    validate_accuracy = 0.0
    for _ in range(steps):
        batch_accuracy = sess.run(
            accuracy, next_batch(validate_images, validate_labels))
        # Accumulate the per-batch mean directly into the overall mean
        validate_accuracy += batch_accuracy / steps
    summary = tf.Summary()
    summary.value.add(tag="accuracy", simple_value=validate_accuracy)
    return summary, validate_accuracy

# Helper to log status
def log_status(step, train_summary, train_accuracy,
               validate_summary=None, validate_accuracy=None):
    """Record summaries and print a one-line status for `step`.

    Always writes `train_summary`; when the optional validation pair is
    provided, also writes `validate_summary` and appends the validation
    accuracy to the printed line.
    """
    train_writer.add_summary(train_summary, step)
    train_writer.flush()
    if validate_summary is not None:
        # BUG FIX: the step was not being passed to add_summary, so
        # every validation point landed at step 0 in TensorBoard.
        validate_writer.add_summary(validate_summary, step)
        validate_writer.flush()
    print("Step %i: training=%f" % (step, train_accuracy), end="")
    if validate_accuracy is not None:
        print(" validation=%f" % validate_accuracy, end="")
    print()

# Helper to save the trained model
saver = tf.train.Saver()
def save_model():
    """Write a checkpoint of the current session under <rundir>/model."""
    print("Saving trained model")
    model_dir = FLAGS.rundir + "/model"
    tf.gfile.MakeDirs(model_dir)
    saver.save(sess, model_dir + "/export")

# Training loop
steps_per_epoch = support.TRAINING_IMAGES_COUNT // FLAGS.batch_size
steps = steps_per_epoch * FLAGS.epochs
step = 0
# NOTE: a while loop (not range) so that `step == steps` after the
# loop, which the final status logging below relies on.
while step < steps:
    # One optimization step; also capture the summary and accuracy so
    # logging does not need a second forward pass.
    _, train_summary, train_accuracy = sess.run(
        [train, summaries, accuracy],
        next_batch(train_images, train_labels))
    if step % 20 == 0:
        at_epoch_boundary = step % steps_per_epoch == 0
        if at_epoch_boundary:
            # Validate and checkpoint once per epoch
            validate_summary, validate_accuracy = validate()
            log_status(step, train_summary, train_accuracy,
                       validate_summary, validate_accuracy)
            save_model()
        else:
            log_status(step, train_summary, train_accuracy)
    step += 1

# Final status: one last train + validate measurement after the loop
train_summary, train_accuracy = sess.run(
    [summaries, accuracy],
    next_batch(train_images, train_labels))
validate_summary, validate_accuracy = validate()
log_status(step, train_summary, train_accuracy,
           validate_summary, validate_accuracy)

# Save trained model
save_model()

# Stop queue runners
coord.request_stop()
Expand Down
27 changes: 15 additions & 12 deletions cifar10/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def progress(count, block_size, total_size):
# Inputs
###################################################################

def inputs(data_dir, data_type, batch_size):
def placeholder_inputs():
    """Return (images, labels) feed placeholders for the model.

    `images` is float32 with shape [batch, height, width, depth];
    `labels` is an int32 vector of class ids.
    """
    image_shape = [None, CROPPED_IMAGE_HEIGHT, CROPPED_IMAGE_WIDTH,
                   IMAGE_DEPTH]
    return (tf.placeholder(tf.float32, image_shape),
            tf.placeholder(tf.int32, [None]))

def data_inputs(data_dir, data_type, batch_size):

# Input file reader
filenames = input_filenames(data_dir, data_type)
Expand Down Expand Up @@ -112,9 +118,7 @@ def data_path(name):
# Inference
###################################################################

def inference(train_images, validate_images, validate):

images = tf.cond(validate, lambda: validate_images, lambda: train_images)
def inference(images):

# First convolutional layer
W_conv1 = weight_variable([5, 5, 3, 64], 0.05)
Expand Down Expand Up @@ -184,21 +188,20 @@ def loss(logits, labels):
# Train
###################################################################

def train(loss, global_step, batch_size):
    """Build the training op for `loss`.

    Uses plain SGD with a staircase exponentially-decayed learning
    rate: initial 0.1, multiplied by 0.1 every EPOCHS_PER_DECAY epochs
    of `batch_size` batches. Returns (train_op, learning_rate).
    """
    decay_steps = (TRAINING_IMAGES_COUNT // batch_size) * EPOCHS_PER_DECAY
    learning_rate = tf.train.exponential_decay(
        0.1, global_step, decay_steps, 0.1, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    apply_op = optimizer.apply_gradients(
        optimizer.compute_gradients(loss), global_step=global_step)
    return apply_op, learning_rate

###################################################################
# Accuracy
###################################################################

def accuracy(logits, labels):
    """Fraction of examples whose top-1 prediction matches `labels`.

    NOTE(review): the mean is accumulated in float16, which limits
    precision for large batches — confirm this is intentional before
    relying on small accuracy deltas.
    """
    correct = tf.nn.in_top_k(logits, labels, 1)
    return tf.reduce_mean(tf.cast(correct, tf.float16))

0 comments on commit 9a68f03

Please sign in to comment.