
Commit

ace19-dev committed Jul 31, 2019
2 parents 342e6b2 + 72ca747 commit 530e517
Showing 3 changed files with 34 additions and 8 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -2,9 +2,9 @@


## TODO
- [On progress] Multi GPU
- TTA
- [On progress] Balanced Batch
- [testing...] Multi GPU
- [testing...] TTA
- [developing...] Balanced Batch

## References
- https://github.com/ildoonet/tf-mobilenet-v2
19 changes: 14 additions & 5 deletions train.py
@@ -121,6 +121,8 @@
TRAIN_DATA_SIZE = 263+384+285+606+215+457+648+219+516+233+490+393 # 4709
VALIDATE_DATA_SIZE = 46+68+50+107+38+81+114+38+91+41+86+70 # 830

TEN_CROP = 10


def show_batch_data(filenames, batch_x, batch_y, additional_path=None):
default_path = '/home/ace19/Pictures/'
@@ -218,11 +220,13 @@ def main(unused_argv):
# is_training=is_training,
# keep_prob=keep_prob,
# attention_module='se_block')

logit = tf.cond(is_training,
lambda: tf.identity(logit),
lambda: tf.reduce_mean(tf.reshape(logit, [FLAGS.val_batch_size, TEN_CROP, -1]), axis=1))
logits.append(logit)

l = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=ground_truth,
logits=logit)
logits=logit)
losses.append(l)
loss_w_reg = tf.reduce_sum(l) + tf.add_n(slim.losses.get_regularization_losses(scope=scope_name))
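A minimal sketch of the crop-averaging branch above, with illustrative sizes (val_batch_size = 4, TEN_CROP = 10, 12 classes; none of these numbers come from the commit):

import tensorflow as tf

val_batch_size, TEN_CROP, num_classes = 4, 10, 12  # illustrative sizes only
# One row of logits per crop, as the network emits them at validation time.
per_crop_logits = tf.zeros([val_batch_size * TEN_CROP, num_classes])
# Regroup the crops belonging to each image, then average over the crop axis.
per_image_logits = tf.reduce_mean(
    tf.reshape(per_crop_logits, [val_batch_size, TEN_CROP, -1]), axis=1)
# per_image_logits has shape [val_batch_size, num_classes]: one prediction per image.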

@@ -401,16 +405,21 @@ def main(unused_argv):
sess.run(val_iterator.initializer, feed_dict={tfrecord_filenames: validate_record_filenames})
for step in range(val_batches):
filenames, validation_batch_xs, validation_batch_ys = sess.run(val_next_batch)
# show_batch_data(filenames, validation_batch_xs, validation_batch_ys)
# TTA
batch_size, n_crops, h, w, c = validation_batch_xs.shape
# fuse batch size and n_crops into a single batch dimension
tencrop_val_batch_xs = np.reshape(validation_batch_xs, (-1, h, w, c))
# show_batch_data(filenames, tencrop_val_batch_xs, validation_batch_ys)

# augmented_val_batch_xs = aug_utils.aug(validation_batch_xs)
# augmented_val_batch_xs = aug_utils.aug(tencrop_val_batch_xs)
# show_batch_data(filenames, augmented_val_batch_xs,
# validation_batch_ys, 'aug')

# TODO: Verify TTA(TenCrop)
val_summary, val_loss, val_top1_acc, _confusion_matrix = sess.run(
[summary_op, loss, top1_acc, confusion_matrix],
feed_dict={
X: validation_batch_xs,
X: tencrop_val_batch_xs,
ground_truth: validation_batch_ys,
is_training: False,
keep_prob: 1.0
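For reference, a minimal numpy sketch of the fuse step in the validation loop above, with illustrative shapes (2 images, 10 crops, 224x224 RGB):

import numpy as np

# Illustrative validation batch: [batch, crops, height, width, channels].
validation_batch_xs = np.zeros((2, 10, 224, 224, 3), dtype=np.float32)
batch_size, n_crops, h, w, c = validation_batch_xs.shape
# Collapse the batch and crop axes so every crop is forwarded as its own example.
tencrop_val_batch_xs = np.reshape(validation_batch_xs, (-1, h, w, c))
print(tencrop_val_batch_xs.shape)  # (20, 224, 224, 3)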
17 changes: 17 additions & 0 deletions val_data.py
@@ -10,6 +10,8 @@
MEAN=[0.485, 0.456, 0.406]
STD=[0.229, 0.224, 0.225]

RANDOM_CROP_SIZE = 200


class Dataset(object):
"""
@@ -30,6 +32,7 @@ def __init__(self, tfrecord_path, batch_size, num_classes, num_epochs, data_size
# of the self.dataset.
self.dataset = self.dataset.map(self.decode, num_parallel_calls=8)
# self.dataset = self.dataset.map(self.augment, num_parallel_calls=8)
self.dataset = self.dataset.map(self.tencrop, num_parallel_calls=8)
self.dataset = self.dataset.map(self.normalize, num_parallel_calls=8)

# Prefetches a batch at a time to smooth out the time taken to load input
@@ -80,6 +83,20 @@ def augment(self, filename, image, label):
return filename, image, label


def tencrop(self, filename, image, label):
    """TenCrop-style TTA: five random crops of the image plus their
    horizontal flips (10 views in total), each resized back to
    (resize_h, resize_w).
    """
    images = []
    for i in range(5):
        img = tf.random_crop(image, [RANDOM_CROP_SIZE, RANDOM_CROP_SIZE, 3])
        img = tf.image.resize(img, [self.resize_h, self.resize_w])
        images.append(img)
        images.append(tf.image.flip_left_right(img))

    return filename, tf.stack(images), label


def normalize(self, filename, image, label):
# """Convert `image` from [0, 255] -> [-0.5, 0.5] floats."""
# image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
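The tencrop mapping above takes five random crops plus their horizontal flips; the canonical TenCrop uses the four corner crops and the center crop instead. A sketch of that deterministic variant, assuming the decoded image is at least RANDOM_CROP_SIZE pixels on each side (the function name and offsets are illustrative, not part of the commit):

import tensorflow as tf

def corner_center_crops(image, crop_size, resize_h, resize_w):
    """Four corner crops plus the center crop, each with its horizontal flip (10 views)."""
    shape = tf.shape(image)
    h, w = shape[0], shape[1]
    offsets = [
        (0, 0),                                        # top-left
        (0, w - crop_size),                            # top-right
        (h - crop_size, 0),                            # bottom-left
        (h - crop_size, w - crop_size),                # bottom-right
        ((h - crop_size) // 2, (w - crop_size) // 2),  # center
    ]
    views = []
    for off_h, off_w in offsets:
        crop = tf.image.crop_to_bounding_box(image, off_h, off_w, crop_size, crop_size)
        crop = tf.image.resize(crop, [resize_h, resize_w])
        views.append(crop)
        views.append(tf.image.flip_left_right(crop))
    return tf.stack(views)  # shape [10, resize_h, resize_w, 3]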
