Skip to content

Commit

Permalink
Fix debug mode, now done via hook.
Browse files Browse the repository at this point in the history
  • Loading branch information
dekked authored and vierja committed Sep 16, 2017
1 parent 0c822fd commit 8fc6eeb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
3 changes: 2 additions & 1 deletion luminoth/models/pretrained/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def load_weights(self, checkpoint_file):
)

tf.logging.debug(
'Loading {} variables from pretrained checkpoint {}'.format(
'Constructing op to load {} variables from pretrained '
'checkpoint {}'.format(
len(load_variables), checkpoint_file
))

Expand Down
23 changes: 16 additions & 7 deletions luminoth/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import json
import tensorflow as tf

from tensorflow.python import debug as tf_debug

from luminoth.datasets import TFRecordDataset
from luminoth.models import get_model
from luminoth.utils.config import (
Expand Down Expand Up @@ -113,6 +115,8 @@ def run(target, cluster_spec, is_chief, model_type, config_file,
# TODO: Make optional for different types of models.
load_op = model.load_pretrained_weights()

# TODO: what is this? probably broken since code changes for
# distributed training
# saver = model.get_saver()
# if config.train.ignore_scope:
# partial_loader = model.get_saver(
Expand Down Expand Up @@ -182,25 +186,30 @@ def run(target, cluster_spec, is_chief, model_type, config_file,
# is not restored from checkpoint.
scaffold = tf.train.Scaffold(init_op=init_op)

#
# Custom hooks for our session
#
hooks = []
if config.train.tf_debug:
debug_hook = tf_debug.LocalCLIDebugHook()
debug_hook.add_tensor_filter(
'has_inf_or_nan', tf_debug.has_inf_or_nan
)
hooks.extend([debug_hook])

with tf.train.MonitoredTrainingSession(
master=target,
is_chief=is_chief,
checkpoint_dir=config.train.job_dir,
scaffold=scaffold,
hooks=hooks,
save_checkpoint_secs=config.train.save_checkpoint_secs,
save_summaries_steps=config.train.save_summaries_steps,
save_summaries_secs=config.train.save_summaries_secs,
) as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

if config.train.tf_debug:
from tensorflow.python import debug as tensorflow_debug
sess = tensorflow_debug.LocalCLIDebugWrapperSession(sess)
sess.add_tensor_filter(
'has_inf_or_nan', tensorflow_debug.has_inf_or_nan
)

try:
while not coord.should_stop():
_, train_loss, step, filename = sess.run([
Expand Down

0 comments on commit 8fc6eeb

Please sign in to comment.