Skip to content

Commit

Permalink
Evaluate and serve support
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Dec 1, 2016
1 parent 3ce3594 commit 3fb7eda
Showing 1 changed file with 50 additions and 2 deletions.
52 changes: 50 additions & 2 deletions cifar10/single.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import print_function

import argparse
import json
import os

import numpy as np
Expand All @@ -12,7 +13,7 @@

def train():

# Placeholder inputs
# Placeholders for inputs
images, labels = support.placeholder_inputs()

# Training data
Expand Down Expand Up @@ -44,6 +45,10 @@ def train():
train_writer = tf.train.SummaryWriter(FLAGS.rundir + "/train")
validate_writer = tf.train.SummaryWriter(FLAGS.rundir + "/validation")

# Inputs/outputs for running exported model
tf.add_to_collection("inputs", json.dumps({"image": images.name}))
tf.add_to_collection("outputs", json.dumps({"prediction": labels.name}))

# Initialize session
sess = tf.Session()
sess.run(tf.global_variables_initializer())
Expand Down Expand Up @@ -123,14 +128,57 @@ def save_model():
validate_summary, validate_accuracy)

# Save trained model
tf.add_to_collection("images", images.name)
tf.add_to_collection("labels", labels.name)
tf.add_to_collection("accuracy", accuracy.name)
save_model()

# Stop queue runners
coord.request_stop()
coord.join(threads)

def evaluate():
print("TODO evaluate")
# Validation data
validate_images, validate_labels = support.data_inputs(
FLAGS.datadir,
support.VALIDATION_DATA,
FLAGS.batch_size)

# Load model
sess = tf.Session()
saver = tf.train.import_meta_graph(FLAGS.rundir + "/model/export.meta")
saver.restore(sess, FLAGS.rundir + "/model/export")

# Tensors used to evaluate
images = sess.graph.get_tensor_by_name(tf.get_collection("images")[0])
labels = sess.graph.get_tensor_by_name(tf.get_collection("labels")[0])
accuracy = sess.graph.get_tensor_by_name(tf.get_collection("accuracy")[0])

# Initialize queue runners
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Evaluate loop
step = 0
steps = support.VALIDATION_IMAGES_COUNT // FLAGS.batch_size
validate_accuracy = 0.0
while step < steps:
batch_images, batch_labels = sess.run(
[validate_images, validate_labels])
batch_input = {
images: batch_images,
labels: batch_labels
}
batch_accuracy = sess.run(accuracy, batch_input)
validate_accuracy += batch_accuracy / steps
step += 1

# Print validation accuracy
print("Validation accuracy: %f" % validate_accuracy)

# Stop queue runners
coord.request_stop()
coord.join(threads)

def main(_):
support.ensure_data(FLAGS.datadir)
Expand Down

0 comments on commit 3fb7eda

Please sign in to comment.