
Commit

update
ace19-dev committed Aug 8, 2019
1 parent 8e95d80 commit fab0e2e
Showing 6 changed files with 90 additions and 103 deletions.
57 changes: 0 additions & 57 deletions datasets/create_query_dataset.py

This file was deleted.

2 changes: 1 addition & 1 deletion retrieval/matching.py
@@ -126,7 +126,7 @@ class NearestNeighborDistanceMetric(object):
     """

-    def __init__(self, metric, matching_threshold, budget=None):
+    def __init__(self, metric, matching_threshold=None, budget=None):


         if metric == "euclidean":
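Note on this change: matching_threshold now defaults to None, so the metric can be constructed for pure ranking without choosing a gating threshold. A minimal sketch of the two call styles — the "cosine" metric name comes from the constructor's own check above; the retrieval-side call is an assumption:

    # Tracking-style usage: the threshold gates candidate matches.
    metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)

    # Retrieval-style usage (assumed): rank gallery items by distance only,
    # so the threshold argument can now be omitted.
    metric = NearestNeighborDistanceMetric("cosine")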
4 changes: 2 additions & 2 deletions retrieval/retrieval_data.py
@@ -31,8 +31,8 @@ def __init__(self, tfrecord_path, batch_size, num_classes, num_epochs, data_size
         # The map transformation takes a function and applies it to every element
         # of the self.dataset.
         self.dataset = self.dataset.map(self.decode, num_parallel_calls=8)
-        # self.dataset = self.dataset.map(self.eval, num_parallel_calls=8)
-        self.dataset = self.dataset.map(self.tencrop, num_parallel_calls=8)
+        self.dataset = self.dataset.map(self.eval, num_parallel_calls=8)
+        # self.dataset = self.dataset.map(self.tencrop, num_parallel_calls=8)
         self.dataset = self.dataset.map(self.normalize, num_parallel_calls=8)

         # Prefetches a batch at a time to smooth out the time taken to load input
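With this reordering the query pipeline applies the single-crop eval transform instead of tencrop, so each query image yields one feature vector rather than ten. The method bodies are not part of this diff; a minimal sketch of what a single-crop eval map typically does (resize, then center-crop to the 224x224 input from the flags) — the function name and sizes here are assumptions:

    import tensorflow as tf

    def eval_transform(image, label):
        # Resize to a fixed scale, then take one center crop at the
        # network input resolution instead of ten corner/flip crops.
        image = tf.image.resize_images(image, [256, 256])
        image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
        return image, label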
102 changes: 71 additions & 31 deletions retrieval/retrieve.py
@@ -8,7 +8,7 @@
 import tensorflow as tf

 from retrieval import retrieval_data, matching
-from utils import aug_utils
+from utils import aug_utils, train_utils
 import model
 import train_data
@@ -21,23 +21,43 @@

 # Dataset settings.
 flags.DEFINE_string('dataset_dir',
-                    '/home/ace19/dl_data/materials',
+                    '/home/ace19/dl_data/v2-plant-seedlings-dataset-resized',
                     'Where the dataset reside.')

 flags.DEFINE_string('output_dir',
                     '/home/ace19/dl_results/image_retrieve/_result',
                     'Where the dataset reside.')

-flags.DEFINE_string('checkpoint_path',
-                    '../tfmodels',
+flags.DEFINE_string('pre_trained_checkpoint',
+                    '../tfmodels/best.ckpt-5',
                     'Directory where to read training checkpoints.')
+flags.DEFINE_string('checkpoint_exclude_scopes',
+                    'ball/mean_vectors,ball/scale',
+                    # None,
+                    'Comma-separated list of scopes of variables to exclude '
+                    'when restoring from a checkpoint.')
+flags.DEFINE_string('checkpoint_model_scope',
+                    'tower0/resnet_v2_50',
+                    'Model scope in the checkpoint. None if the same as the trained model.')
+flags.DEFINE_string('model_name',
+                    'resnet_v2_50',
+                    'The name of the architecture to train.')
+flags.DEFINE_string('extra_model_name',
+                    'fc1',
+                    # None,
+                    'The name of the architecture to extra train.')
+flags.DEFINE_string('checkpoint_model_scope2',
+                    'tower0/fc1',
+                    'Model scope in the checkpoint. None if the same as the trained model.')

 flags.DEFINE_integer('batch_size', 32, 'batch size')
 flags.DEFINE_integer('height', 224, 'height')
 flags.DEFINE_integer('width', 224, 'width')
+# flags.DEFINE_string('labels',
+#                     'airplane,bed,bookshelf,toilet,vase',
+#                     'number of classes')
+flags.DEFINE_string('labels',
+                    'Black_grass,Charlock,Cleavers,Common_Chickweed,Common_wheat,Fat_Hen,'
+                    'Loose_Silky_bent,Maize,Scentless_Mayweed,Shepherds_Purse,'
+                    'Small_flowered_Cranesbill,Sugar_beet',
+                    'Labels to use')

 # # retrieval params
 # flags.DEFINE_float('max_cosine_distance', 0.2,
@@ -47,8 +67,11 @@
 #                      'If None, no budget is enforced.')


-GALLERY_SIZE = 43955
-QUERY_SIZE = 300
+# GALLERY_SIZE = 43955
+# QUERY_SIZE = 300
+
+GALLERY_SIZE = 263+384+285+606+215+457+648+219+516+233+490+393  # 4709
+QUERY_SIZE = 60

 TOP_N = 5
 TEN_CROP = 10
@@ -142,11 +165,24 @@ def main(unused_argv):
                          is_reuse=False,
                          keep_prob=keep_prob,
                          attention_module='se_block')
-    features = tf.cond(is_training,
-                       lambda: tf.identity(features),
-                       lambda: tf.reduce_mean(tf.reshape(features, [FLAGS.batch_size, TEN_CROP, -1]), axis=1))

-    # Prepare query source data
+    # Print name and shape of parameter nodes (values not yet initialized)
+    tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
+    tf.compat.v1.logging.info("Parameters")
+    tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
+    for v in slim.get_model_variables():
+        tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
+
+    # features = tf.cond(is_training,
+    #                    lambda: tf.identity(features),
+    #                    lambda: tf.reduce_mean(tf.reshape(features, [FLAGS.batch_size, TEN_CROP, -1]), axis=1))
+
+    # Create a saver object which will save all the variables
+    saver = tf.compat.v1.train.Saver()
+
+    ###############
+    # Prepare data
+    ###############
     tfrecord_filenames = tf.placeholder(tf.string, shape=[])
     gallery_dataset = train_data.Dataset(tfrecord_filenames,
                                          FLAGS.batch_size,
@@ -163,24 +199,28 @@ def main(unused_argv):
                                            num_classes,
                                            None,
                                            QUERY_SIZE,
-                                           256, # 256 ~ 480
-                                           256)
+                                           FLAGS.height,
+                                           FLAGS.width)
+                                           # 256, # 256 ~ 480
+                                           # 256)
     query_iterator = query_dataset.dataset.make_initializable_iterator()
     query_next_batch = query_iterator.get_next()


-    sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
-    with tf.Session(config=sess_config) as sess:
+    sess_config = tf.compat.v1.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
+    with tf.compat.v1.Session(config=sess_config) as sess:
         sess.run(tf.global_variables_initializer())

-        # Create a saver object which will save all the variables
-        saver = tf.train.Saver()
-        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
-            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
-        else:
-            checkpoint_path = FLAGS.checkpoint_path
-        saver.restore(sess, checkpoint_path)
+        # TODO: supports multi gpu - add scope ('tower%d' % gpu_idx)
+        if FLAGS.pre_trained_checkpoint:
+            train_utils.restore_fn(FLAGS)
+
+        # if FLAGS.pre_trained_checkpoint:
+        #     if tf.gfile.IsDirectory(FLAGS.pre_trained_checkpoint):
+        #         checkpoint_path = tf.train.latest_checkpoint(FLAGS.pre_trained_checkpoint)
+        #     else:
+        #         checkpoint_path = FLAGS.pre_trained_checkpoint
+        #     saver.restore(sess, checkpoint_path)

         # global_step = checkpoint_path.split('/')[-1].split('-')[-1]
Expand Down Expand Up @@ -208,7 +248,7 @@ def main(unused_argv):

# (10,512)
_f = sess.run(features, feed_dict={X: gallery_batch_xs,
is_training:True,
is_training:False,
keep_prob: 1.0})
gallery_features_list.extend(_f)
gallery_path_list.extend(filenames)
@@ -221,13 +261,13 @@ def main(unused_argv):
                 filenames, query_batch_xs, query_batch_ys = sess.run(query_next_batch)
                 # show_batch_data(filenames, query_batch_xs, query_batch_ys)

-                # TTA
-                batch_size, n_crops, c, h, w = query_batch_xs.shape
-                # fuse batch size and ncrops
-                tencrop_query_batch_xs = np.reshape(query_batch_xs, (-1, c, h, w))
+                # # TTA
+                # batch_size, n_crops, c, h, w = query_batch_xs.shape
+                # # fuse batch size and ncrops
+                # tencrop_query_batch_xs = np.reshape(query_batch_xs, (-1, c, h, w))

                 # (10,512)
-                _f = sess.run(features, feed_dict={X: tencrop_query_batch_xs,
+                _f = sess.run(features, feed_dict={X: query_batch_xs,
                                                    is_training:False,
                                                    keep_prob: 1.0})
                 query_features_list.extend(_f)
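Taken together, retrieve.py now embeds both gallery and query images with a single crop and ranks gallery entries per query. The ranking itself is outside this diff; a minimal NumPy sketch of top-N retrieval by cosine distance, consistent with TOP_N = 5 above (array and function names are assumptions):

    import numpy as np

    def top_n_matches(query_features, gallery_features, top_n=5):
        # L2-normalize rows so a dot product equals cosine similarity.
        q = query_features / np.linalg.norm(query_features, axis=1, keepdims=True)
        g = gallery_features / np.linalg.norm(gallery_features, axis=1, keepdims=True)
        distances = 1.0 - q.dot(g.T)  # shape: (num_queries, num_gallery)
        # Indices of the top_n closest gallery items for every query.
        return np.argsort(distances, axis=1)[:, :top_n]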
21 changes: 10 additions & 11 deletions train.py
@@ -41,7 +41,7 @@

 flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                   'Learning rate policy for training.')
-flags.DEFINE_float('base_learning_rate', 0.003,
+flags.DEFINE_float('base_learning_rate', 0.002,
                    'The base learning rate for model training.')
 flags.DEFINE_float('learning_rate_decay_factor', 1e-4,
                    'The rate to decay the base learning rate.')
@@ -72,10 +72,6 @@
                     # './tfmodels',
                     None,
                     'Saved checkpoint dir.')
-# flags.DEFINE_string('saved_checkpoint_path',
-#                     # './tfmodels/best_resnet_v2_50.ckpt',
-#                     None,
-#                     'Saved checkpoint path.')
 flags.DEFINE_string('pre_trained_checkpoint',
                     'pre-trained/resnet_v2_50.ckpt',
                     # None,
@@ -172,12 +168,6 @@ def main(unused_argv):
     # for k, v in end_points.items():
     #     tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
     #
-    # # # Print name and shape of parameter nodes (values not yet initialized)
-    # # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
-    # # tf.compat.v1.logging.info("Parameters")
-    # # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
-    # # for v in slim.get_model_variables():
-    # #     tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

     # Gather initial summaries.
     summaries = set(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))
@@ -219,6 +209,14 @@ def main(unused_argv):
                           is_reuse=False,
                           keep_prob=keep_prob,
                           attention_module='se_block')
+
+            # Print name and shape of parameter nodes (values not yet initialized)
+            tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
+            tf.compat.v1.logging.info("Parameters")
+            tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
+            for v in slim.get_model_variables():
+                tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
+
             # TTA
             logit = tf.cond(is_training,
                             lambda: tf.identity(logit),
@@ -335,6 +333,7 @@ def main(unused_argv):
         train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir, graph)
         validation_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/validation', graph)

+        # TODO: supports multi gpu - add scope ('tower%d' % gpu_idx)
         if FLAGS.pre_trained_checkpoint:
             train_utils.restore_fn(FLAGS)
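Unlike retrieve.py, train.py keeps the ten-crop TTA average at eval time: the ten crops of each image pass through the network as one flat batch and their logits are averaged back per image, as in the tf.cond shown in the previous hunk. A self-contained sketch of that reshape-and-reduce step (batch size and class count are stand-ins for the repo's flags):

    import tensorflow as tf

    TEN_CROP = 10
    batch_size = 32   # stand-in for FLAGS.batch_size
    num_classes = 12  # stand-in matching the 12 seedling labels

    # At eval time the pipeline emits ten crops per image, so the network
    # produces (batch_size * TEN_CROP, num_classes) logits.
    logit = tf.compat.v1.placeholder(tf.float32, [batch_size * TEN_CROP, num_classes])

    # Group the ten crops back per image and average their predictions.
    logit_tta = tf.reduce_mean(
        tf.reshape(logit, [batch_size, TEN_CROP, -1]), axis=1)  # (batch_size, num_classes)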
7 changes: 6 additions & 1 deletion utils/train_utils.py
@@ -342,7 +342,12 @@ def restore_fn(flags):
                 flags.checkpoint_model_scope): var
             for var in variables_to_restore}

-    # supports multi gpu?
+    # # modify
+    # if flags.extra_model_name is not None:
+    #     for key, var in variables_to_restore.items():
+    #         if key.split('/')[0] == flags.extra_model_name:
+    #             key.replace(flags.extra_model_name, flags.checkpoint_model_scope2)
+
     slim.assign_from_checkpoint_fn(flags.pre_trained_checkpoint,
                                    variables_to_restore)
     tf.compat.v1.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' %
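Note: the commented-out remap above would be a no-op even if enabled — str.replace returns a new string whose result is discarded, so the dict keys never change. If the fc1 variables ever need to be remapped to checkpoint_model_scope2, the mapping has to be rebuilt; a corrected sketch under that assumption:

    # Rebuild the mapping instead of renaming keys in place.
    remapped = {}
    for key, var in variables_to_restore.items():
        if key.split('/')[0] == flags.extra_model_name:
            key = key.replace(flags.extra_model_name, flags.checkpoint_model_scope2, 1)
        remapped[key] = var
    variables_to_restore = remapped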
