Skip to content

Commit

Permalink
Switch for using data augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Dec 21, 2016
1 parent 3aa553d commit 055430e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 3 additions & 1 deletion cifar10/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def train():
# Training data
train_images, train_labels = support.data_inputs(
FLAGS.datadir,
support.TRAINING_DATA,
(support.AUGMENTED_TRAINING_DATA
if FLAGS.augment else support.TRAINING_DATA),
FLAGS.batch_size,
FLAGS.runner_threads)

Expand Down Expand Up @@ -212,5 +213,6 @@ def main(_):
parser.add_argument("--evaluate", action="store_true")
parser.add_argument("--runner_threads", type=int, default=4)
parser.add_argument("--decay_epochs", type=int, default=20)
parser.add_argument("--augment", action="store_true")
FLAGS, _ = parser.parse_known_args()
tf.app.run()
5 changes: 3 additions & 2 deletions cifar10/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
INPUT_RECORD_BYTES = INPUT_LABEL_BYTES + INPUT_IMAGE_BYTES

TRAINING_DATA = 1
VALIDATION_DATA = 2
AUGMENTED_TRAINING_DATA = 2
VALIDATION_DATA = 3

###################################################################
# Download data
Expand Down Expand Up @@ -87,7 +88,7 @@ def data_inputs(data_dir, data_type, batch_size, runner_threads):

# Finalize image
image_float = tf.cast(image_hwd, tf.float32)
if data_type == TRAINING_DATA:
if data_type == AUGMENTED_TRAINING_DATA:
image_final = augmented_standardized_image(image_float)
else:
image_final = standardized_image(image_float)
Expand Down

0 comments on commit 055430e

Please sign in to comment.