This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Fixes Checkpointing #5220

Merged: 40 commits merged into main from Checkpointing on May 29, 2021
Commits (40)
fb6b00b  Removes unused variable (dirkgr, May 19, 2021)
7d15a73  Formatting (dirkgr, May 19, 2021)
1475729  Make sure we always restore the model's weights properly (dirkgr, May 20, 2021)
f247cdd  Give TrainerCallbacks the ability to save and load state dicts (dirkgr, May 20, 2021)
fd14dd7  Give MovingAverage the ability to save and load state dicts (dirkgr, May 20, 2021)
27d90d0  Do not set gradients to None (dirkgr, May 20, 2021)
7612741  Typo (dirkgr, May 20, 2021)
d894b25  Remove unused variable (dirkgr, May 20, 2021)
e0e917e  Typo (dirkgr, May 20, 2021)
e52e7ad  Entirely new checkpointing code (dirkgr, May 24, 2021)
bead35e  Formatting (dirkgr, May 24, 2021)
c5e8537  Merge remote-tracking branch 'origin/main' into Checkpointing (dirkgr, May 24, 2021)
6128a88  Make mypy happy (dirkgr, May 24, 2021)
cef052d  Makes the no-op trainer work with the new checkpointer (dirkgr, May 24, 2021)
fdc8db7  Mark epochs as completed when they're skipped (dirkgr, May 24, 2021)
b3b65c2  Changelog (dirkgr, May 24, 2021)
6c944c2  Fixes how we get the best weights after a training run (dirkgr, May 24, 2021)
0799137  Mypy is annoying (dirkgr, May 24, 2021)
3cae291  Callback fixes (dirkgr, May 25, 2021)
e04eae2  Fix the no op trainer (dirkgr, May 25, 2021)
d873df1  Simplify (dirkgr, May 25, 2021)
5480fd5  Assorted checkpointer fixes (dirkgr, May 25, 2021)
eafb48a  Mypy is now happy (dirkgr, May 25, 2021)
37c13a7  Fixed all the tests except for one (dirkgr, May 25, 2021)
872ac20  Removed unused variable (dirkgr, May 25, 2021)
b8545fd  Fix trainer restore logic (dirkgr, May 25, 2021)
273386d  Fix test for trainer restore logic (dirkgr, May 25, 2021)
abc7826  Merge remote-tracking branch 'origin/main' into Checkpointing (dirkgr, May 25, 2021)
69909ff  Check the Checkpointing branch of the models repo (dirkgr, May 25, 2021)
4ca4912  Help mypy along (dirkgr, May 25, 2021)
7af394f  Fixed finalizing logic (dirkgr, May 25, 2021)
302c672  More mypy stuff (dirkgr, May 25, 2021)
cd33ad2  Merge branch 'main' into Checkpointing (dirkgr, May 25, 2021)
153608b  Merge branch 'main' into Checkpointing (dirkgr, May 27, 2021)
9f97c38  Merge branch 'main' into Checkpointing (dirkgr, May 27, 2021)
f95d437  Merge branch 'main' into Checkpointing (dirkgr, May 27, 2021)
7811679  Merge branch 'main' into Checkpointing (dirkgr, May 28, 2021)
82353ed  Update allennlp/training/checkpointer.py (dirkgr, May 29, 2021)
5537194  Merge remote-tracking branch 'origin/main' into Checkpointing (dirkgr, May 29, 2021)
5aa9b00  Make weaker claims (dirkgr, May 29, 2021)
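Two of the commits above (f247cdd and fd14dd7) give TrainerCallbacks and MovingAverage the ability to save and load state dicts, so that callback state can survive a checkpoint/restore cycle. Below is a minimal standalone sketch of that pattern, assuming an interface modeled on torch.nn.Module's state_dict/load_state_dict convention; the class names and fields here are hypothetical illustrations, not allennlp's actual API.

from typing import Any, Dict


class StatefulCallback:
    """Hypothetical base class; allennlp's actual TrainerCallback may differ."""

    def state_dict(self) -> Dict[str, Any]:
        # Callbacks with nothing to checkpoint contribute an empty dict.
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Default: nothing to restore.
        pass


class BatchCountingCallback(StatefulCallback):
    """Example callback whose counter should survive a restart."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def on_batch(self) -> None:
        self.batches_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


# A checkpointer can then persist callback state alongside model weights:
callback = BatchCountingCallback()
callback.on_batch()
saved = callback.state_dict()

restored = BatchCountingCallback()
restored.load_state_dict(saved)
assert restored.batches_seen == 1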
Changes from 1 commit: Callback fixes
dirkgr committed May 25, 2021
commit 3cae29107fd6240f8a4b279c94ead4c3751ba090
6 changes: 3 additions & 3 deletions in allennlp/training/callbacks/log_writer.py

@@ -291,13 +291,13 @@ def log_epoch(
     def _should_log_distributions_next_batch(self) -> bool:
         return (
             self._distribution_interval is not None
-            and (self.trainer._batch_num_total + 1) % self._distribution_interval == 0  # type: ignore[union-attr]
+            and (self.trainer._total_batches_completed + 1) % self._distribution_interval == 0
         )

     def _should_log_distributions_this_batch(self) -> bool:
         return (
             self._distribution_interval is not None
-            and self.trainer._batch_num_total % self._distribution_interval == 0  # type: ignore[union-attr]
+            and self.trainer._total_batches_completed % self._distribution_interval == 0
         )

     def _enable_activation_logging(self) -> None:
@@ -318,7 +318,7 @@ def hook(module_, inputs, outputs):
         self._module_hook_handles.append(module.register_forward_hook(hook))

     def _should_log_this_batch(self) -> bool:
-        return self.trainer._batch_num_total % self._summary_interval == 0  # type: ignore[union-attr]
+        return self.trainer._total_batches_completed % self._summary_interval == 0  # type: ignore[union-attr]

     def _log_activation_distribution(self, outputs: Any, module_name: str) -> None:
         activations_to_log: Dict[str, torch.Tensor] = {}
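The renamed counter makes the off-by-one semantics explicit: _total_batches_completed counts batches that have already finished, so (completed + 1) % interval == 0 asks whether the next batch falls on a logging boundary, while completed % interval == 0 asks about the batch that just finished. A runnable standalone sketch of the two checks (the interval and counter values are made up for illustration):

def should_log_next_batch(total_batches_completed: int, interval: int) -> bool:
    # The upcoming batch will be number total_batches_completed + 1.
    return (total_batches_completed + 1) % interval == 0


def should_log_this_batch(total_batches_completed: int, interval: int) -> bool:
    # The batch that just finished is number total_batches_completed.
    return total_batches_completed % interval == 0


if __name__ == "__main__":
    interval = 10
    # After 9 completed batches, batch 10 is next, so distribution logging is armed...
    assert should_log_next_batch(9, interval)
    # ...and once batch 10 has completed, the distributions are actually logged.
    assert should_log_this_batch(10, interval)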
4 changes: 2 additions & 2 deletions in allennlp/training/callbacks/tensorboard.py

@@ -49,7 +49,7 @@ def log_scalars(
         log_prefix: str = "",
         epoch: Optional[int] = None,
     ) -> None:
-        timestep = epoch if epoch is not None else self.trainer._batch_num_total  # type: ignore[union-attr]
+        timestep = epoch if epoch is not None else self.trainer._total_batches_completed  # type: ignore[union-attr]
         log = self._train_log if not log_prefix.startswith("validation") else self._validation_log
         for key, value in scalars.items():
             name = f"{log_prefix}/{key}" if log_prefix else key
@@ -59,7 +59,7 @@ def log_tensors(
     def log_tensors(
         self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None
     ) -> None:
-        timestep = epoch if epoch is not None else self.trainer._batch_num_total  # type: ignore[union-attr]
+        timestep = epoch if epoch is not None else self.trainer._total_batches_completed  # type: ignore[union-attr]
         log = self._train_log if not log_prefix.startswith("validation") else self._validation_log
         for key, values in tensors.items():
             name = f"{log_prefix}/{key}" if log_prefix else key
2 changes: 1 addition & 1 deletion in allennlp/training/callbacks/wandb.py

@@ -134,7 +134,7 @@ def _log(
         dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()}
         if epoch is not None:
             dict_to_log["epoch"] = epoch
-        self.wandb.log(dict_to_log, step=self.trainer._batch_num_total)  # type: ignore[union-attr]
+        self.wandb.log(dict_to_log, step=self.trainer._total_batches_completed)  # type: ignore[union-attr]

     @overrides
     def on_start(
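The wandb change matters beyond naming: Weights & Biases expects the step passed to wandb.log to be monotonically non-decreasing, so the counter used here must survive a checkpoint restore, which is what the checkpointing rework in this PR is about. A minimal sketch of the pattern, assuming the wandb package is installed; the project name and metric are made up, and offline mode is used so the snippet runs without an account:

import wandb

run = wandb.init(project="checkpointing-demo", mode="offline")  # hypothetical project

total_batches_completed = 0  # in the real trainer this is restored from a checkpoint
for batch in range(100):
    loss = 1.0 / (batch + 1)  # stand-in for a real training loss
    total_batches_completed += 1
    # Passing an explicit, monotonically increasing step keeps metrics aligned
    # across restarts instead of letting wandb assign its own internal counter.
    wandb.log({"train/loss": loss}, step=total_batches_completed)

run.finish()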