
Commit

update
ace19-dev committed Aug 9, 2019
1 parent 22f56ef · commit ae43ed3
Showing 4 changed files with 20 additions and 40 deletions.
retrieval/retrieve.py (20 changes: 6 additions & 14 deletions)
@@ -28,7 +28,7 @@
                     '/home/ace19/dl_results/image_retrieve/_result',
                     'Where the dataset reside.')
 
-flags.DEFINE_string('pre_trained_checkpoint',
+flags.DEFINE_string('checkpoint_dir',
                     '../tfmodels/best.ckpt-5',
                     'Directory where to read training checkpoints.')
 flags.DEFINE_string('checkpoint_exclude_scopes',
@@ -43,13 +43,6 @@
 flags.DEFINE_string('model_name',
                     'resnet_v2_50',
                     'The name of the architecture to train.')
-# flags.DEFINE_string('extra_model_name',
-#                     'fc1',
-#                     # None,
-#                     'The name of the architecture to extra train.')
-# flags.DEFINE_string('checkpoint_model_scope2',
-#                     'tower0/fc1',
-#                     'Model scope in the checkpoint. None if the same as the trained model.')
 
 flags.DEFINE_integer('batch_size', 64, 'batch size')
 flags.DEFINE_integer('height', 224, 'height')
@@ -211,15 +204,14 @@ def main(unused_argv):
     with tf.compat.v1.Session(config=sess_config) as sess:
         sess.run(tf.global_variables_initializer())
 
-        # TODO: supports multi gpu - add scope ('tower%d' % gpu_idx)
-        if FLAGS.pre_trained_checkpoint:
+        if FLAGS.checkpoint_dir:
             train_utils.custom_restore_fn(FLAGS)
 
-        # if FLAGS.pre_trained_checkpoint:
-        #     if tf.gfile.IsDirectory(FLAGS.pre_trained_checkpoint):
-        #         checkpoint_path = tf.train.latest_checkpoint(FLAGS.pre_trained_checkpoint)
+        # if FLAGS.checkpoint_dir:
+        #     if tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
+        #         checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
         #     else:
-        #         checkpoint_path = FLAGS.pre_trained_checkpoint
+        #         checkpoint_path = FLAGS.checkpoint_dir
         #     saver.restore(sess, checkpoint_path)
 
         # global_step = checkpoint_path.split('/')[-1].split('-')[-1]
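Note: the retrieval script's flag is renamed from pre_trained_checkpoint to checkpoint_dir, though its default ('../tfmodels/best.ckpt-5') is a checkpoint prefix rather than a directory. A minimal, hypothetical helper, mirroring the commented-out block above and not part of this commit, that would accept either form:

    import tensorflow as tf

    def resolve_checkpoint(checkpoint_dir):
        # Accept either a directory of checkpoints or a concrete
        # checkpoint prefix such as '../tfmodels/best.ckpt-5'.
        if tf.gfile.IsDirectory(checkpoint_dir):
            # Pick the newest checkpoint in the directory.
            return tf.train.latest_checkpoint(checkpoint_dir)
        return checkpoint_dir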
train.py (14 changes: 7 additions & 7 deletions)
@@ -210,12 +210,12 @@ def main(unused_argv):
                                        keep_prob=keep_prob,
                                        attention_module='se_block')
 
-    # Print name and shape of parameter nodes (values not yet initialized)
-    tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
-    tf.compat.v1.logging.info("Parameters")
-    tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
-    for v in slim.get_model_variables():
-        tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
+    # # Print name and shape of parameter nodes (values not yet initialized)
+    # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
+    # tf.compat.v1.logging.info("Parameters")
+    # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
+    # for v in slim.get_model_variables():
+    #     tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
 
     # # TTA
     # logit = tf.cond(is_training,
@@ -333,7 +333,7 @@ def main(unused_argv):
     train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir, graph)
     validation_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/validation', graph)
 
-    # TODO: supports multi gpu - add scope ('tower%d' % gpu_idx)
+    # TODO: supports multi gpu -> add scope ('tower%d' % gpu_idx)
     if FLAGS.pre_trained_checkpoint:
         train_utils.restore_fn(FLAGS)
 
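Note: the startup dump of parameter names and shapes is commented out rather than deleted. A hypothetical alternative, not part of this commit, would gate the dump behind a boolean flag so it can be re-enabled from the command line; the sketch below assumes train.py's existing flags, FLAGS, slim, and tf names are in scope.

    # Hypothetical: keep the parameter dump, but off by default.
    flags.DEFINE_boolean('log_model_variables', False,
                         'Print name and shape of model variables at startup.')

    if FLAGS.log_model_variables:
        for v in slim.get_model_variables():
            tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))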
train_data.py (12 changes: 6 additions & 6 deletions)
@@ -94,14 +94,14 @@ def distort_image(self, filename, image, label):
         color_ordering = random.randint(0, 1)
         if color_ordering == 0:
             image = tf.image.random_brightness(image, max_delta=32. / 255.)
-            image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
-            image = tf.image.random_hue(image, max_delta=0.2)
-            image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+            image = tf.image.random_saturation(image, lower=0.2, upper=1.2)
+            image = tf.image.random_hue(image, max_delta=0.1)
+            image = tf.image.random_contrast(image, lower=0.2, upper=1.2)
         elif color_ordering == 1:
             image = tf.image.random_brightness(image, max_delta=32. / 255.)
-            image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
-            image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
-            image = tf.image.random_hue(image, max_delta=0.2)
+            image = tf.image.random_contrast(image, lower=0.2, upper=1.2)
+            image = tf.image.random_saturation(image, lower=0.2, upper=1.2)
+            image = tf.image.random_hue(image, max_delta=0.1)
 
         # The random_* ops do not necessarily clamp.
         image = tf.clip_by_value(image, 0.0, 1.0)
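Note: the saturation and contrast jitter ranges shift from (0.5, 1.5) down to (0.2, 1.2), and the hue delta is halved from 0.2 to 0.1. One caveat, offered as an observation rather than as part of the commit: random.randint(0, 1) runs at Python level, so if distort_image is traced once (e.g. inside a tf.data map), the op ordering is fixed for the whole pipeline rather than drawn per image. A hypothetical per-image variant using tf.cond:

    # Hypothetical per-image ordering (not in the commit): draw the branch
    # inside the graph so each example gets its own op order.
    def distort_color(image):
        def order_a(img):
            img = tf.image.random_brightness(img, max_delta=32. / 255.)
            img = tf.image.random_saturation(img, lower=0.2, upper=1.2)
            img = tf.image.random_hue(img, max_delta=0.1)
            return tf.image.random_contrast(img, lower=0.2, upper=1.2)

        def order_b(img):
            img = tf.image.random_brightness(img, max_delta=32. / 255.)
            img = tf.image.random_contrast(img, lower=0.2, upper=1.2)
            img = tf.image.random_saturation(img, lower=0.2, upper=1.2)
            return tf.image.random_hue(img, max_delta=0.1)

        pick_a = tf.less(tf.random.uniform([], 0, 1), 0.5)
        image = tf.cond(pick_a, lambda: order_a(image), lambda: order_b(image))
        # The random_* ops do not necessarily clamp; clip back to [0, 1].
        return tf.clip_by_value(image, 0.0, 1.0)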
utils/train_utils.py (14 changes: 1 addition & 13 deletions)
@@ -342,12 +342,6 @@ def restore_fn(flags):
                                 flags.checkpoint_model_scope): var
                             for var in variables_to_restore}
 
-    # # modify
-    # if flags.extra_model_name is not None:
-    #     for key, var in variables_to_restore.items():
-    #         if key.split('/')[0] == flags.extra_model_name:
-    #             key.replace(flags.extra_model_name, flags.checkpoint_model_scope2)
-
     slim.assign_from_checkpoint_fn(flags.pre_trained_checkpoint,
                                    variables_to_restore)
     tf.compat.v1.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' %
@@ -383,7 +377,7 @@ def custom_restore_fn(flags):
                 excluded = True
                 break
         if not excluded:
-            variables_to_restore.append(var)
+            variables_to_restore.append('tower0/' + var)
 
     # Change model scope if necessary.
     if flags.checkpoint_model_scope is not None:
@@ -392,12 +386,6 @@ def custom_restore_fn(flags):
                                 flags.checkpoint_model_scope): var
                             for var in variables_to_restore}
 
-    # # modify
-    # if flags.extra_model_name is not None:
-    #     for key, var in variables_to_restore.items():
-    #         if key.split('/')[0] == flags.extra_model_name:
-    #             key.replace(flags.extra_model_name, flags.checkpoint_model_scope2)
-
     slim.assign_from_checkpoint_fn(flags.pre_trained_checkpoint,
                                    variables_to_restore)
     tf.compat.v1.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' %
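Note: custom_restore_fn now appends 'tower0/' + var instead of var. If var is a tf.Variable object, as in the usual slim exclusion loop, that string concatenation would raise a TypeError; the apparent intent is a name-level remap from un-prefixed checkpoint names to tower-scoped graph variables. A hypothetical sketch of such a mapping (helper name is illustrative, not from the repo); both tf.compat.v1.train.Saver and slim.assign_from_checkpoint_fn accept a {checkpoint_name: variable} dict:

    def build_tower0_restore_map(model_variables):
        # Map each tower-scoped graph variable back to the un-prefixed
        # name it would carry in a single-tower checkpoint.
        return {v.op.name.replace('tower0/', '', 1): v
                for v in model_variables
                if v.op.name.startswith('tower0/')}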
