Skip to content

Commit

Permalink
Add validation accuracy to the cifar10 example
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Nov 29, 2016
1 parent 68c8224 commit 676cc28
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
4 changes: 2 additions & 2 deletions cifar10/Guild
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ train = upstream_single
datadir = CIFAR10_data
rundir = $RUNDIR
epochs = 10
batch_size = 128
batch_size = 100

[view]

fields = train-accuracy steps time
fields = validation-accuracy train-accuracy steps time

series-a = accuracy

Expand Down
67 changes: 48 additions & 19 deletions cifar10/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import os

import numpy as np
import tensorflow as tf

import support
Expand All @@ -24,22 +25,24 @@ def train():
FLAGS.batch_size)

# Model and training ops
train_predict = support.inference(train_images, FLAGS.batch_size)
loss = support.loss(train_predict, train_labels)
validate_flag = tf.placeholder(tf.bool, ())
predict = support.inference(
train_images, validate_images, FLAGS.batch_size, validate_flag)
loss = support.loss(predict, train_labels)
global_step = tf.Variable(0, trainable=False)
train, learning_rate = support.train(loss, FLAGS.batch_size, global_step)

# Accuracy
train_accuracy = support.accuracy(train_predict, train_labels)
validate_predict = support.inference(validate_images, FLAGS.batch_size)
validate_accuracy = support.accuracy(validate_predict, validate_labels)
accuracy = support.accuracy(
predict, train_labels, validate_labels, validate_flag)

# Summaries
tf.scalar_summary("loss", loss)
tf.scalar_summary("accuracy", train_accuracy)
tf.scalar_summary("accuracy", accuracy)
tf.scalar_summary("learning_rate", learning_rate)
train_summaries = tf.merge_all_summaries()
summaries = tf.merge_all_summaries()
train_writer = tf.train.SummaryWriter(FLAGS.rundir + "/train")
validate_writer = tf.train.SummaryWriter(FLAGS.rundir + "/validation")

# Initialize session
sess = tf.Session()
Expand All @@ -49,24 +52,50 @@ def train():
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Helper to log status
def log_status(step):
loss_, summaries, accuracy = sess.run(
[loss, train_summaries, train_accuracy])
train_writer.add_summary(summaries, step)
print("Step %i: loss=%f accuracy=%s" % (step, loss_, accuracy))
# Helpers to log status
def log_status(train_step, validate=False):
    """Print a one-line status report for train_step.

    Always logs training accuracy; also logs validation accuracy
    when validate is true (end of an epoch, or the final step).
    """
    print("Step %i:" % train_step, end="")
    log_train_status(train_step)
    if validate:
        log_validate_status(train_step)
    print()

def log_train_status(step):
    """Evaluate one training batch, record its summaries, print accuracy."""
    # validate_flag False routes the graph through the training inputs.
    fetched = sess.run([summaries, accuracy], feed_dict={validate_flag: False})
    summary_str, train_acc = fetched
    train_writer.add_summary(summary_str, step)
    train_writer.flush()
    print(" training=%f" % train_acc, end="")

def log_validate_status(train_step):
    """Average accuracy over the whole validation set; record and print it."""
    # One pass over the validation images, a batch at a time.
    batch_count = support.VALIDATION_IMAGES_COUNT // FLAGS.batch_size
    batch_accuracies = [
        sess.run(accuracy, feed_dict={validate_flag: True})
        for _ in range(batch_count)
    ]
    mean_accuracy = float(np.mean(batch_accuracies))
    # Hand-built summary: the graph's merged summaries are wired to the
    # training batch, so validation accuracy is written out manually.
    summary = tf.Summary()
    summary.value.add(tag="accuracy", simple_value=mean_accuracy)
    validate_writer.add_summary(summary, train_step)
    validate_writer.flush()
    print(" validation=%f" % mean_accuracy, end="")

# Training loop
steps = (support.TRAINING_IMAGES_COUNT // FLAGS.batch_size) * FLAGS.epochs
steps_per_epoch = support.TRAINING_IMAGES_COUNT // FLAGS.batch_size
train_steps = steps_per_epoch * FLAGS.epochs
step = 0
while step < steps:
sess.run(train)
while step < train_steps:
sess.run(train, feed_dict={validate_flag: False})
if step % 20 == 0:
log_status(step)
validate = step > 0 and step % steps_per_epoch == 0
log_status(step, validate)
step += 1

# Final status
log_status(step)
log_status(step, True)

# Stop queue runners
coord.request_stop()
Expand All @@ -88,7 +117,7 @@ def main(_):
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", default="/tmp/CIFAR10_data",)
parser.add_argument("--rundir", default="/tmp/CIFAR10_train")
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=100)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--prepare", action="store_true")
parser.add_argument("--evaluate", action="store_true")
Expand Down
16 changes: 10 additions & 6 deletions cifar10/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DATA_BIN_NAME = "cifar-10-batches-bin"

TRAINING_IMAGES_COUNT = 50000
TEST_IMAGES_COUNT = 10000
VALIDATION_IMAGES_COUNT = 10000
CLASS_COUNT = 10

IMAGE_HEIGHT = 32
Expand Down Expand Up @@ -99,19 +99,22 @@ def inputs(data_dir, data_type, batch_size):
return images, tf.reshape(labels, [batch_size])

def input_filenames(data_dir, data_type):
    """Return the CIFAR-10 binary batch file paths for data_type.

    data_type must be TRAINING_DATA (the five data_batch_*.bin files)
    or VALIDATION_DATA (the single test_batch.bin file); any other
    value raises ValueError.
    """
    def data_path(name):
        # All batch files live under <data_dir>/<DATA_BIN_NAME>/.
        return os.path.join(data_dir, DATA_BIN_NAME, name)
    if data_type == TRAINING_DATA:
        return [data_path("data_batch_%i.bin" % i) for i in range(1, 6)]
    elif data_type == VALIDATION_DATA:
        return [data_path("test_batch.bin")]
    else:
        raise ValueError(data_type)

###################################################################
# Inference
###################################################################

def inference(images, batch_size):
def inference(train_images, validate_images, batch_size, validate):

images = tf.cond(validate, lambda: validate_images, lambda: train_images)

# First convolutional layer
W_conv1 = weight_variable([5, 5, 3, 64], 0.05)
Expand Down Expand Up @@ -196,6 +199,7 @@ def train(loss, batch_size, global_step):
# Accuracy
###################################################################

def accuracy(logits, labels):
def accuracy(logits, train_labels, validate_labels, validate):
    """Return mean top-1 accuracy of logits against the active label batch.

    The validate placeholder selects which label tensor matches the
    images currently flowing through the shared inference graph.
    """
    active_labels = tf.cond(validate, lambda: validate_labels, lambda: train_labels)
    correct = tf.cast(tf.nn.in_top_k(logits, active_labels, 1), tf.float16)
    return tf.reduce_mean(correct)

0 comments on commit 676cc28

Please sign in to comment.