Commit 1193b81
Fixes some issues with saving and loading checkpoints
- Use a generic saver for ease of use when using a Scaffold (in both train and eval)
- Init variables and then load the pretrained weights (it's important that these two operations are executed in that order)
- Remove the `--continue-training` option, since continuing from an existing checkpoint is now the default behaviour (see the sketch below)
- Use a different GraphKey for grabbing pretrained variables: using GLOBAL_VARIABLES also yielded optimizer variables, which we didn't want to load
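How the "default behaviour" works is worth spelling out; the following is a minimal standalone TF 1.x sketch, not Luminoth's actual training loop (the checkpoint directory and the dummy loss are placeholders):

import tensorflow as tf

# With a Scaffold, tf.train.MonitoredTrainingSession restores the latest
# checkpoint found in checkpoint_dir and only runs the scaffold's init_op when
# the directory holds no checkpoint, so resuming training needs no extra flag.
global_step = tf.train.get_or_create_global_step()
loss = tf.get_variable('dummy_loss', initializer=1.0)  # placeholder for the model's loss
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

scaffold = tf.train.Scaffold(
    init_op=tf.global_variables_initializer(),
    local_init_op=tf.local_variables_initializer(),
)

with tf.train.MonitoredTrainingSession(
        checkpoint_dir='/tmp/example_job_dir',  # hypothetical job dir
        scaffold=scaffold) as sess:
    sess.run(train_op)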
vierja committed Sep 20, 2017
1 parent 35999e6 commit 1193b81
Showing 3 changed files with 14 additions and 20 deletions.
luminoth/eval.py (2 additions, 2 deletions)
@@ -101,8 +101,8 @@ def evaluate(model_type, dataset_split, config_file, job_dir, watch,
         tf.local_variables_initializer()
     )
 
-    # Get the saver required to load model parameters.
-    saver = model.get_saver()
+    # Using a global saver instead of the one for the model.
+    saver = tf.train.Saver(sharded=True, allow_empty=True)
 
     # Aggregate the required ops to evaluate into a dict..
     ops = {
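For context, a hedged sketch of how such a generic saver restores a training checkpoint on the eval side (standalone TF 1.x code, not the actual body of evaluate(); the variable and the checkpoint directory are placeholders):

import tensorflow as tf

# A generic, shard-aware saver restores whatever variables exist in the graph
# from the training checkpoint; allow_empty=True keeps construction from
# failing when a graph happens to define no saveable variables.
some_var = tf.get_variable('some_model_variable', shape=[3], initializer=tf.zeros_initializer())
saver = tf.train.Saver(sharded=True, allow_empty=True)

job_dir = '/tmp/example_job_dir'  # hypothetical checkpoint directory
with tf.Session() as sess:
    sess.run(tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer(),
    ))
    ckpt = tf.train.latest_checkpoint(job_dir)
    if ckpt is not None:
        saver.restore(sess, ckpt)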
luminoth/models/pretrained/pretrained.py (1 addition, 1 deletion)
@@ -54,7 +54,7 @@ def load_weights(self, checkpoint_file):
             return tf.no_op(name='not_loading_pretrained')
 
         module_variables = snt.get_variables_in_module(
-            self, tf.GraphKeys.GLOBAL_VARIABLES
+            self, tf.GraphKeys.MODEL_VARIABLES
         )
         assert len(module_variables) > 0
 
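To see why the collection matters, here is a standalone TF 1.x / slim sketch (assumed scope and layer names, not Luminoth code): optimizer slot variables are added to GLOBAL_VARIABLES but not to MODEL_VARIABLES, so collecting by MODEL_VARIABLES keeps them out of the pretrained restore.

import tensorflow as tf
import tensorflow.contrib.slim as slim

# slim layers register their weights and biases in both GLOBAL_VARIABLES and
# MODEL_VARIABLES; the Momentum slot variables created by the optimizer only
# end up in GLOBAL_VARIABLES.
inputs = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('pretrained'):  # hypothetical scope name
    net = slim.fully_connected(inputs, 2, scope='fc')
loss = tf.reduce_mean(net)
train_op = tf.train.MomentumOptimizer(0.01, momentum=0.9).minimize(loss)

global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='pretrained')
model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope='pretrained')

# global_vars contains the weights, the biases and their Momentum slots;
# model_vars contains only the weights and biases we actually want to restore.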
luminoth/train.py (11 additions, 17 deletions)
@@ -15,9 +15,9 @@
 )
 
 
-def run(model_type, config_file, override_params, continue_training, seed,
-        target='', cluster_spec=None, is_chief=True, job_name=None,
-        task_index=None, **kwargs):
+def run(model_type, config_file, override_params, seed, target='',
+        cluster_spec=None, is_chief=True, job_name=None, task_index=None,
+        **kwargs):
 
     if seed:
         tf.set_random_seed(seed)
@@ -62,10 +62,6 @@ def run(model_type, config_file, override_params, continue_training, seed,
     prediction_dict = model(train_image, train_bboxes, training=True)
     total_loss = model.loss(prediction_dict)
 
-    # Load pretrained weights needs to be called before defining the train
-    # op. After it, variables for the optimizer are created.
-    load_pretrained_op = model.load_pretrained_weights()
-
     global_step = tf.contrib.framework.get_or_create_global_step()
 
     optimizer = get_optimizer(config.train, global_step)
@@ -84,12 +80,6 @@ def run(model_type, config_file, override_params, continue_training, seed,
         grads_and_vars, global_step=global_step
     )
 
-    init_op = tf.group(
-        tf.global_variables_initializer(),
-        # Queue-related variables need a special initializer.
-        tf.local_variables_initializer()
-    )
-
     tf.logging.info('{}Starting training for {}'.format(log_prefix, model))
 
     run_options = None
@@ -98,14 +88,19 @@ def run(model_type, config_file, override_params, continue_training, seed,
             trace_level=tf.RunOptions.FULL_TRACE
         )
 
+    # Load pretrained weights needs to be called before defining the train
+    # op. After it, variables for the optimizer are created.
+    with tf.control_dependencies([tf.global_variables_initializer()]):
+        with tf.control_dependencies([model.load_pretrained_weights()]):
+            init_op = tf.no_op(name='global_init_load_pretrained')
+
     # Create custom Scaffold to make sure we run our own init_op when model
     # is not restored from checkpoint.
     scaffold = tf.train.Scaffold(
         # Initialize local and global variables.
         init_op=init_op,
-        # Load pretrained weights after init_op.
-        local_init_op=load_pretrained_op,
-        saver=model.get_saver(),
+        # Queue-related variables need a special initializer.
+        local_init_op=tf.local_variables_initializer(),
         summary_op=tf.summary.merge([
             tf.summary.merge_all(),
             model.summary,
@@ -180,7 +175,6 @@ def run(model_type, config_file, override_params, continue_training, seed,
 @click.option('model_type', '--model', required=True, default='fasterrcnn') # noqa
 @click.option('config_file', '--config', '-c', help='Config to use.')
 @click.option('override_params', '--override', '-o', multiple=True, help='Override model config params.') # noqa
-@click.option('--continue-training', is_flag=True, help='Continue training using model dir and run name.') # noqa
 @click.option('--seed', type=float, help='Global seed value for random operations.') # noqa
 @click.option('--checkpoint-file', help='Weight checkpoint to resuming training from.') # noqa
 @click.option('--ignore-scope', help='Used to ignore variables when loading from checkpoint.') # noqa
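The init-then-load ordering in the train.py changes above relies only on graph-level control dependencies; below is a minimal standalone sketch of the same pattern (the weights variable and the assign op are stand-ins, not Luminoth code):

import tensorflow as tf

# Running init_op first runs the global initializer, then the stand-in
# pretrained-loading op, because ops created inside a control_dependencies
# block must wait for the listed ops.
weights = tf.get_variable('weights', shape=[2], initializer=tf.zeros_initializer())

with tf.control_dependencies([tf.global_variables_initializer()]):
    # Created inside the block, so it only runs after the initializer.
    load_pretrained = tf.assign(weights, [1.0, 2.0])  # stand-in for model.load_pretrained_weights()
    with tf.control_dependencies([load_pretrained]):
        init_op = tf.no_op(name='global_init_load_pretrained')

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(weights))  # [1. 2.]: initialized first, then overwritten by the "pretrained" values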
