This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Fixes Checkpointing #5220

Merged: dirkgr merged 40 commits into main from the Checkpointing branch on May 29, 2021

Conversation

dirkgr (Member) commented May 24, 2021

  • Checkpointing and restarting from a checkpoint now works when the training job is interrupted half-way through an epoch.
  • The checkpointer is no longer responsible for writing out the current best model. The trainer has to do this now.
  • GradientDescentTrainer now lives in its own file. I had to do this to break a circular dependency between Checkpointer and GradientDescentTrainer.
  • Callbacks can now save and restore state.
  • When training with moving average, restoring checkpoints now works correctly.
  • When re-starting an interrupted training job, the trainer will now read out the data loader even for epochs and batches that can be skipped. This is necessary to ensure that any random number generators used by the reader or data loader are in the same state as they were the first time the training job ran (see the sketch below).
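To make the last bullet concrete, here is a minimal sketch of the resume behavior, not the actual GradientDescentTrainer code; `epochs_completed` and `batches_in_epoch_completed` are placeholder names for the counters a checkpoint would restore.

```python
# Minimal sketch of the resume behavior described above (not the real trainer code).
# `epochs_completed` and `batches_in_epoch_completed` stand in for whatever
# counters the checkpoint actually restores.
def resume_training(data_loader, train_one_batch, num_epochs,
                    epochs_completed=0, batches_in_epoch_completed=0):
    for epoch in range(num_epochs):
        for batch_index, batch in enumerate(data_loader):
            already_trained = epoch < epochs_completed or (
                epoch == epochs_completed and batch_index < batches_in_epoch_completed
            )
            if already_trained:
                # Still pull the batch from the data loader so that any RNGs used
                # by the reader or loader advance exactly as in the original run,
                # but skip the gradient update itself.
                continue
            train_one_batch(batch)
```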

dirkgr (Member, Author) commented May 24, 2021

@epwalsh, you can look at this now, while I'm fixing tests. What do we need to change to make fairscale work?

epwalsh (Member) commented May 24, 2021

In a meeting now but I'll take a look afterwards

epwalsh (Member) left a comment

These are great improvements, but they don't really change the story with FairScale.

One thing that's missing is synchronization across distributed workers when gathering model and training state, since collecting the state associated with sharded parameters requires a distributed gather operation (each worker needs to send its shard of the data to the main process).

Another issue is that the optimizer state actually has to be collected through the FullyShardedDataParallel model wrapper (gather_full_optim_state).
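For illustration only, a rough sketch of the kind of save path this implies. The `gather_full_optim_state` call is the method named above; its exact signature, and everything else here, is an assumption rather than FairScale's confirmed API.

```python
import torch
import torch.distributed as dist


def save_sharded_checkpoint(fsdp_model, optimizer, path, is_primary):
    # All workers must reach this point together before gathering, since each
    # one has to contribute its shard of the parameters/optimizer state.
    dist.barrier()

    # Optimizer state is collected through the FullyShardedDataParallel wrapper
    # (method name taken from the comment above; signature is an assumption).
    full_optim_state = fsdp_model.gather_full_optim_state(optimizer)

    if is_primary:
        torch.save({"model": fsdp_model.state_dict(), "optimizer": full_optim_state}, path)

    # Hold everyone here until the primary worker has finished writing.
    dist.barrier()
```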

Comment on lines 49 to 53
save_completed_epochs: bool = True,
save_every_num_seconds: Optional[int] = None,
save_every_num_batches: Optional[int] = None,
keep_most_recent_by_count: Optional[int] = 2,
keep_most_recent_by_age: Optional[int] = None,
epwalsh (Member): Thank you, I hated the old names 💯

GradientDescentTrainer,
)
from allennlp.training.trainer import Trainer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
epwalsh (Member): Nice. I've been wanting to move this to its own file for a while.

Comment on lines +80 to +84
def state_dict(self) -> Dict[str, Any]:
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass
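These default no-op hooks mean a callback only needs to override them when it actually has state worth checkpointing. A hypothetical example (the callback name, its counter, and the simplified on_batch signature are made up for illustration):

```python
from typing import Any, Dict

from allennlp.training.callbacks import TrainerCallback


@TrainerCallback.register("batch_counter")
class BatchCounterCallback(TrainerCallback):
    """Illustrative only: counts batches and survives checkpoint/restore."""

    def __init__(self, serialization_dir: str) -> None:
        super().__init__(serialization_dir)
        self.batches_seen = 0

    def on_batch(self, trainer, *args, **kwargs) -> None:
        self.batches_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        # Whatever is returned here ends up in the training checkpoint.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called when training resumes from a checkpoint.
        self.batches_seen = state_dict["batches_seen"]
```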
dirkgr (Member, Author) commented May 24, 2021

Deep in the throes of fixing all the tests, I'm wondering if I should have fixed this the other way around. Saving and restoring in the middle of an epoch was added to the checkpointer, but it's completely unsupported by any other part of the system. This is essentially a new piece of functionality.

dirkgr marked this pull request as ready for review on May 25, 2021, 22:52
dirkgr (Member, Author) commented May 25, 2021

Tests pass locally. I'm still fighting with mypy and the models repo. We might have to retrain some stuff (or at least patch the model configs), because the num_serialized_models_to_keep parameter went away.

But overall, this is ready to review.

@@ -152,6 +152,7 @@ jobs:
run: |
git clone https://github.com/allenai/allennlp-models.git
cd allennlp-models
git checkout Checkpointing
dirkgr (Member, Author):

Suggested change:
- git checkout Checkpointing

This will have to be removed before merging.

dirkgr (Member, Author) commented May 25, 2021

GradientDescentTrainer is by and large the same. While reviewing, only look at the bits that have to do with checkpointing, and the _start_after_* variables.

You can also review this one commit at a time. I kept the commits pretty clean and self-contained. That'll let you skip the big copy of GradientDescentTrainer.

dirkgr (Member, Author) commented May 25, 2021

We should do a minor version bump after this. It changes some public APIs.

dirkgr (Member, Author) commented May 27, 2021

@epwalsh, this is ready for a real review now.

dirkgr self-assigned this on May 27, 2021
epwalsh (Member) left a comment

This looks great. I just left a few comments.

Comment on lines +134 to +138
extra_copy_of_weights_just_for_mypy = Path(weights)
if extra_copy_of_weights_just_for_mypy.is_absolute():
weights_file = extra_copy_of_weights_just_for_mypy
else:
weights_file = Path(serialization_dir) / extra_copy_of_weights_just_for_mypy
epwalsh (Member): This is a little confusing. How about just use typing.cast?

dirkgr (Member, Author):

serialization_dir can still be a str at that point. It's not just to let mypy know what it is.
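To make the exchange concrete: the snippet above normalizes a weights value that may still be a plain str (relative or absolute) into a real path. A standalone sketch of that logic, with assumed names rather than the merged code:

```python
from pathlib import Path
from typing import Union


def resolve_weights_file(weights: Union[str, Path],
                         serialization_dir: Union[str, Path]) -> Path:
    # Both arguments may still be plain strings here, which is why the diff
    # converts explicitly rather than only casting for mypy's benefit.
    weights_path = Path(weights)
    if weights_path.is_absolute():
        return weights_path
    return Path(serialization_dir) / weights_path
```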

Comment on lines +28 to +41
save_completed_epochs : `bool`, (default=`True`)
Saves model and trainer state at the end of each completed epoch.
save_every_num_seconds : `int`, optional (default=`None`)
If set, makes sure we never go longer than this number of seconds between saving a model.
save_every_num_batches : `int`, optional (default=`None`)
If set, makes sure we never go longer than this number of batches between saving a model.
keep_most_recent_by_count : `int`, optional (default=`2`)
Sets the number of model checkpoints to keep on disk. If both `keep_most_recent_by_count` and
`keep_most_recent_by_age` are set, we'll keep checkpoints that satisfy either criterion.
If both are `None`, we keep all checkpoints.
keep_most_recent_by_age : `int`, optional (default=`None`)
Sets the number of seconds we'll keep a checkpoint before deleting it. If both
`keep_most_recent_by_count` and `keep_most_recent_by_age` are set, we'll keep checkpoints
that satisfy either criterion. If both are `None`, we keep all checkpoints.
epwalsh (Member): Nice, this is much more clear.

dirkgr (Member, Author): Unfortunately it breaks backwards compatibility. Worth it, I think, but not great.
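For reference, constructing the checkpointer with the renamed parameters might look roughly like this. The parameter names come from the diff above; the import path matches allennlp/training/checkpointer.py, but the serialization_dir argument and the specific values are assumptions:

```python
from allennlp.training.checkpointer import Checkpointer

checkpointer = Checkpointer(
    serialization_dir="/tmp/my_run",   # assumed argument, not shown in the diff
    save_completed_epochs=True,        # checkpoint at the end of every epoch
    save_every_num_seconds=3600,       # and at least once an hour mid-epoch
    keep_most_recent_by_count=2,       # keep only the two newest checkpoints
)
```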

allennlp/training/checkpointer.py (outdated, resolved)

CHANGELOG.md (outdated)
@@ -40,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
- Fixed documentation for `GradientDescentTrainer.cuda_device`.
- Re-starting a training run from a checkpoint in the middle of an epoch now works correctly.
- When using the "moving average" weights smoothing feature of the trainer, training checkpoints would also get smoothed, with strange results for resuming a training job. This has been fixed.
- When re-starting an interrupted training job, the trainer will now read out the data loader even for epochs and batches that can be skipped. This ensures that any random number generators used by the reader or data loader are in the same state as they were the first time the training job ran.
epwalsh (Member): This sounds good, in theory, but there are probably other things that affect the random number generators used by the reader and data loader. I don't think we can guarantee the same order.

dirkgr (Member, Author):
Hmm. I wrote it this way because in Quark it worked out that way. I had good enough control over the RNGs that it was deterministic.

In AllenNLP, we can't guarantee that none of the things we're skipping when restoring from a checkpoint (the forward() method for example) modify the RNG state. I guess I'll say that this is an attempt to ensure deterministic randomness, but does not guarantee it. At the same time, we should encourage components to use their own RNG instead of using the global one, so they don't affect each other.
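The "use their own RNG" suggestion could look something like this; a hypothetical sketch, not existing AllenNLP code:

```python
import random


class ShufflingReader:
    """Hypothetical reader that owns its RNG instead of using the global one."""

    def __init__(self, seed: int = 13370) -> None:
        # A private Random instance: whatever other components do to the global
        # `random` module (e.g. inside forward()), this reader's shuffle order
        # stays the same, so replaying skipped epochs remains deterministic.
        self._rng = random.Random(seed)

    def shuffle(self, instances: list) -> list:
        shuffled = list(instances)
        self._rng.shuffle(shuffled)
        return shuffled
```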

dirkgr (Member, Author): It's actually quite bad if this doesn't work. If we don't guarantee the order of instances, and we stop training 10 times in the middle of an epoch and restart it, we might end up training on the same instance 10 times.

dirkgr (Member, Author): 11 times, even.

dirkgr enabled auto-merge (squash) on May 29, 2021, 02:03
dirkgr merged commit c5bff8b into main on May 29, 2021
dirkgr deleted the Checkpointing branch on May 29, 2021, 02:18
Abhishek-P pushed a commit to Abhishek-P/allennlp that referenced this pull request Aug 11, 2021
* Removes unused variable

* Formatting

* Make sure we always restore the model's weights properly

* Give TrainerCallbacks the ability to save and load state dicts

* Give MovingAverage the ability to save and load state dicts

* Do not set gradients to None

* Typo

* Remove unused variable

* Typo

* Entirely new checkpointing code

* Formatting

* Make mypy happy

lol

* Makes the no-op trainer work with the new checkpointer

* Mark epochs as completed when they're skipped

* Changelog

* Fixes how we get the best weights after a training run

* Mypy is annoying

* Callback fixes

* Fix the no op trainer

* Simplify

* Assorted checkpointer fixes

* Mypy is now happy

* Fixed all the tests except for one

* Removed unused variable

* Fix trainer restore logic

* Fix test for trainer restore logic

* Check the Checkpointing branch of the models repo

* Help mypy along

* Fixed finalizing logic

* More mypy stuff

* Update allennlp/training/checkpointer.py

Co-authored-by: Pete <petew@allenai.org>

* Make weaker claims

Co-authored-by: Pete <petew@allenai.org>