Skip to content

Commit

Permalink
Expose lr decay rate and support for samples
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Dec 2, 2016
1 parent 35f37a9 commit 6a99a73
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 13 deletions.
1 change: 1 addition & 0 deletions cifar10/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
CIFAR10_data
samples
11 changes: 11 additions & 0 deletions cifar10/Guild
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ rundir = $RUNDIR
epochs = 10
batch_size = 100
runner_threads = 4
decay_epochs = 100

[resource "samples"]

runtime = tensorflow
prepare = samples

[flags "samples"]

sample_count = 100
sample_dir = samples

[view]

Expand Down
61 changes: 61 additions & 0 deletions cifar10/samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse
import json

import tensorflow as tf

from tensorflow.core.framework import summary_pb2

import support

FLAGS = None

def prepare_samples(images_op, labels_op):
# Use summary op to generate a PNG from TF native image
inputs = tf.placeholder(tf.float32, [None, support.CROPPED_IMAGE_HEIGHT,
support.CROPPED_IMAGE_WIDTH,
support.IMAGE_DEPTH])
summary = tf.image_summary('input', inputs, 1)

# Init session
sess = tf.Session()

# Initialize queue runners
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Save images and labels
images, labels = sess.run([images_op, labels_op])
i = 1
for image, label in zip(images, labels):
summary_bin = sess.run(summary, feed_dict={inputs: [image]})
image_summary = summary_pb2.Summary()
image_summary.ParseFromString(summary_bin)
basename = FLAGS.sample_dir + "/" + ("%05i" % i)
image_path = basename + ".png"
print "Writing %s" % image_path
with open(image_path, "w") as f:
f.write(image_summary.value[0].image.encoded_image_string)
with open(basename + ".json", "w") as f:
f.write(json.dumps({
"image": image.tolist(),
"label": int(label)
}))
i += 1

# Stop queue runners
coord.request_stop()
coord.join(threads)

def main(_):
images, labels = support.data_inputs(
FLAGS.datadir, support.VALIDATION_DATA, FLAGS.sample_count, 1)
tf.gfile.MakeDirs(FLAGS.sample_dir)
prepare_samples(images, labels)

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", default="/tmp/CIFAR10_data",)
parser.add_argument("--sample_dir", default="/tmp/CIFAR10_samples")
parser.add_argument("--sample_count", type=int, default=100)
FLAGS, _ = parser.parse_known_args()
tf.app.run()
23 changes: 17 additions & 6 deletions cifar10/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,22 @@ def train():
FLAGS.batch_size,
FLAGS.runner_threads)

# Model and training ops
# Model
predict = support.inference(images)

# Training loss
loss = support.loss(predict, labels)

# Global step - syncs training and learning rate decay
global_step = tf.Variable(0, trainable=False)
train, learning_rate = support.train(loss, global_step, FLAGS.batch_size)

# Learning rate (decays)
steps_per_epoch = support.TRAINING_IMAGES_COUNT // FLAGS.batch_size
decay_steps = steps_per_epoch * FLAGS.decay_epochs
learning_rate = support.learning_rate(global_step, decay_steps)

# Training op
train = support.train(loss, learning_rate, global_step)

# Accuracy
accuracy = support.accuracy(predict, labels)
Expand All @@ -49,7 +60,7 @@ def train():

# Inputs/outputs for running exported model
tf.add_to_collection("inputs", json.dumps({"image": images.name}))
tf.add_to_collection("outputs", json.dumps({"prediction": labels.name}))
tf.add_to_collection("outputs", json.dumps({"prediction": predict.name}))

# Initialize session
sess = tf.Session()
Expand Down Expand Up @@ -101,7 +112,6 @@ def save_model():
saver.save(sess, FLAGS.rundir + "/model/export")

# Training loop
steps_per_epoch = support.TRAINING_IMAGES_COUNT // FLAGS.batch_size
steps = steps_per_epoch * FLAGS.epochs
step = 0
while step < steps:
Expand Down Expand Up @@ -196,10 +206,11 @@ def main(_):
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", default="/tmp/CIFAR10_data",)
parser.add_argument("--rundir", default="/tmp/CIFAR10_train")
parser.add_argument("--batch-size", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--prepare", action="store_true")
parser.add_argument("--evaluate", action="store_true")
parser.add_argument("--runner-threads", type=int, default=4)
parser.add_argument("--runner_threads", type=int, default=4)
parser.add_argument("--decay_epochs", type=int, default=20)
FLAGS, _ = parser.parse_known_args()
tf.app.run()
12 changes: 5 additions & 7 deletions cifar10/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
TRAINING_DATA = 1
VALIDATION_DATA = 2

EPOCHS_PER_DECAY = 10

###################################################################
# Download data
###################################################################
Expand Down Expand Up @@ -203,15 +201,15 @@ def loss(logits, labels):
# Train
###################################################################

def train(loss, global_step, batch_size):
batches_per_epoch = TRAINING_IMAGES_COUNT // batch_size
decay_steps = batches_per_epoch * EPOCHS_PER_DECAY
learning_rate = tf.train.exponential_decay(
def learning_rate(global_step, decay_steps):
return tf.train.exponential_decay(
0.1, global_step, decay_steps, 0.1, staircase=True)

def train(loss, learning_rate, global_step):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
gradients = optimizer.compute_gradients(loss)
train = optimizer.apply_gradients(gradients, global_step=global_step)
return train, learning_rate
return train

###################################################################
# Accuracy
Expand Down

0 comments on commit 6a99a73

Please sign in to comment.