Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xlnwel committed Sep 26, 2019
1 parent 27c4fa5 commit ad1353f
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions basic_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def __init__(self,
self.saver = self._setup_saver()
self.model_file = self._setup_model_path(args['model_root_dir'], self.model_name)

pwc(f'{self.name} has been constructed', 'magenta')

self.print_construction_complete()
@property
def global_variables(self):
return super().global_variables + self.graph.get_collection(name=tf.GraphKeys.GLOBAL_VARIABLES, scope='stats')
Expand Down Expand Up @@ -269,6 +269,9 @@ def log_tabular(self, key, value):
def dump_tabular(self, print_terminal_info=True):
self.logger.dump_tabular(print_terminal_info=print_terminal_info)

def print_construction_complete(self):
pwc(f'{self.name} has been constructed', 'cyan')

""" Implementation """
def _setup_saver(self):
return tf.train.Saver(self.global_variables)
Expand Down Expand Up @@ -336,9 +339,11 @@ def _record_stats_impl(self, kwargs):
del kwargs['worker_no']

# if global_step appeas in kwargs, use it when adding summary to tensorboard
step = kwargs['global_step'] if 'global_step' in kwargs else None
del kwargs['global_step']

if 'global_steps' in kwargs:
step = kwargs['global_step']
del kwargs['global_step']
else:
step = None
feed_dict = {}

for k, v in kwargs.items():
Expand Down

0 comments on commit ad1353f

Please sign in to comment.