Remove weird use of batch size in model definition
Garrett Smith committed Dec 1, 2016
1 parent 339f525 commit 66916a6
Showing 2 changed files with 4 additions and 6 deletions.
cifar10/single.py: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ def train():
 
     # Model and training ops
     validate_flag = tf.placeholder(tf.bool, ())
-    predict = support.inference(
-        train_images, validate_images, FLAGS.batch_size, validate_flag)
+    predict = support.inference(train_images, validate_images, validate_flag)
     loss = support.loss(predict, train_labels)
     global_step = tf.Variable(0, trainable=False)
     train, learning_rate = support.train(loss, FLAGS.batch_size, global_step)
cifar10/support.py: 3 additions & 4 deletions
@@ -112,7 +112,7 @@ def data_path(name):
 # Inference
 ###################################################################
 
-def inference(train_images, validate_images, batch_size, validate):
+def inference(train_images, validate_images, validate):
 
     images = tf.cond(validate, lambda: validate_images, lambda: train_images)
 
@@ -131,9 +131,8 @@ def inference(train_images, validate_images, batch_size, validate):
     h_pool2 = max_pool_3x3(h_norm2)
 
     # First locally connected layer
-    h_pool2_flat = tf.reshape(h_pool2, [batch_size, -1])
-    dim = h_pool2_flat.get_shape()[1].value
-    W_local1 = weight_variable([dim, 384], 0.04, 0.004)
+    h_pool2_flat = tf.reshape(h_pool2, [-1, 6 * 6 * 64])
+    W_local1 = weight_variable([6 * 6 * 64, 384], 0.04, 0.004)
     b_local1 = bias_variable([384], 0.1)
     h_local1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_local1) + b_local1)
 
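Why this works: with tf.reshape(h_pool2, [-1, 6 * 6 * 64]), the batch dimension stays dynamic while the flattened feature size is fixed by the architecture, so the graph no longer hard-wires one batch size and the old get_shape()[1].value lookup becomes unnecessary. A minimal TF 1.x sketch of the difference, with a hypothetical placeholder standing in for h_pool2:

import numpy as np
import tensorflow as tf

# Stand-in for the pooled activations; the batch dimension is unknown
# at graph-build time (hypothetical placeholder, not from the commit).
h_pool2 = tf.placeholder(tf.float32, [None, 6, 6, 64])

# Old style: tf.reshape(h_pool2, [batch_size, -1]) baked one batch size
# into the graph. New style leaves it dynamic:
h_pool2_flat = tf.reshape(h_pool2, [-1, 6 * 6 * 64])

with tf.Session() as sess:
    for batch in (1, 32, 128):
        out = sess.run(h_pool2_flat,
                       {h_pool2: np.zeros((batch, 6, 6, 64), np.float32)})
        print(out.shape)  # (1, 2304), (32, 2304), (128, 2304)

The same graph can now serve training and validation batches of different sizes, which is presumably why threading batch_size through the model definition was "weird" in the first place.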

0 comments on commit 66916a6
