Skip to content

Commit

Permalink
Cleanup example code
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Nov 30, 2016
1 parent 339f525 commit 0be33f4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
11 changes: 6 additions & 5 deletions mnist/expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def train(mnist):
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Helper to log performance
# Helper to log status
def log_model_status(step, train_images, train_labels):
train_data = {
x: train_images,
Expand All @@ -95,21 +95,22 @@ def log_model_status(step, train_images, train_labels):
print "Step %i: training=%f validation=%f" % (
step, train_accuracy, validation_accuracy)

# Helper to save model
saver = tf.train.Saver()
def save_model():
print "Saving trained model"
tf.gfile.MakeDirs(FLAGS.rundir + "/model")
saver.save(sess, FLAGS.rundir + "/model/export")

# Batch training over all training examples per epoch
steps = (mnist.train.num_examples // FLAGS.batch_size) * FLAGS.epochs
# Training loop
steps_per_epoch = mnist.train.num_examples // FLAGS.batch_size
steps = steps_per_epoch * FLAGS.epochs
for step in range(steps):
images, labels = mnist.train.next_batch(FLAGS.batch_size)
sess.run(train, feed_dict={x: images, y_: labels})
if step % 20 == 0:
log_model_status(step, images, labels)
if step != 0 and step % (mnist.train.num_examples /
FLAGS.batch_size) == 0:
if step % steps_per_epoch == 0:
save_model()

# Final status
Expand Down
11 changes: 6 additions & 5 deletions mnist/intro.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def train(mnist):
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Helper to log performance
# Helper to log status
def log_model_status(step, train_images, train_labels):
train_data = {
x: train_images,
Expand All @@ -59,21 +59,22 @@ def log_model_status(step, train_images, train_labels):
print "Step %i: training=%f validation=%f" % (
step, train_accuracy, validation_accuracy)

# Helper to save model
saver = tf.train.Saver()
def save_model():
print "Saving trained model"
tf.gfile.MakeDirs(FLAGS.rundir + "/model")
saver.save(sess, FLAGS.rundir + "/model/export")

# Batch training over all training examples per epoch
steps = (mnist.train.num_examples // FLAGS.batch_size) * FLAGS.epochs
# Training loop
steps_per_epoch = mnist.train.num_examples // FLAGS.batch_size
steps = steps_per_epoch * FLAGS.epochs
for step in range(steps):
images, labels = mnist.train.next_batch(FLAGS.batch_size)
sess.run(train, feed_dict={x: images, y_: labels})
if step % 20 == 0:
log_model_status(step, images, labels)
if step != 0 and step % (mnist.train.num_examples /
FLAGS.batch_size) == 0:
if step % steps_per_epoch == 0:
save_model()

# Final status
Expand Down

0 comments on commit 0be33f4

Please sign in to comment.