From 66916a67f17249eda957f650c7f75b11fcf9465b Mon Sep 17 00:00:00 2001 From: Garrett Smith Date: Wed, 30 Nov 2016 18:27:33 -0600 Subject: [PATCH] Remove weird use of batch size in model definition --- cifar10/single.py | 3 +-- cifar10/support.py | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cifar10/single.py b/cifar10/single.py index 60de789..527142f 100644 --- a/cifar10/single.py +++ b/cifar10/single.py @@ -26,8 +26,7 @@ def train(): # Model and training ops validate_flag = tf.placeholder(tf.bool, ()) - predict = support.inference( - train_images, validate_images, FLAGS.batch_size, validate_flag) + predict = support.inference(train_images, validate_images, validate_flag) loss = support.loss(predict, train_labels) global_step = tf.Variable(0, trainable=False) train, learning_rate = support.train(loss, FLAGS.batch_size, global_step) diff --git a/cifar10/support.py b/cifar10/support.py index f6a3572..e9a7a73 100644 --- a/cifar10/support.py +++ b/cifar10/support.py @@ -112,7 +112,7 @@ def data_path(name): # Inference ################################################################### -def inference(train_images, validate_images, batch_size, validate): +def inference(train_images, validate_images, validate): images = tf.cond(validate, lambda: validate_images, lambda: train_images) @@ -131,9 +131,8 @@ def inference(train_images, validate_images, batch_size, validate): h_pool2 = max_pool_3x3(h_norm2) # First locally connected layer - h_pool2_flat = tf.reshape(h_pool2, [batch_size, -1]) - dim = h_pool2_flat.get_shape()[1].value - W_local1 = weight_variable([dim, 384], 0.04, 0.004) + h_pool2_flat = tf.reshape(h_pool2, [-1, 6 * 6 * 64]) + W_local1 = weight_variable([6 * 6 * 64, 384], 0.04, 0.004) b_local1 = bias_variable([384], 0.1) h_local1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_local1) + b_local1)