Fix model loading on GPU post training (#5518)
* Fixes #5511: Load model on CPU post training to save GPU memory

* Edited Changelog
vikigenius authored Dec 20, 2021
1 parent 3552842 commit ab4f7b5
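Background for the change: by default `torch.load` restores each tensor to the device it was serialized from, so a checkpoint written while the model lived on a GPU is deserialized straight back onto that GPU, even though after training the trainer only needs the weights. Passing `map_location` deserializes everything onto the CPU instead. A minimal sketch of the difference (the model, sizes, and file name below are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical model and checkpoint path, purely for illustration.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "best.th")

# Default: tensors come back on the device they were saved from, which
# re-allocates GPU memory if the checkpoint was written from a CUDA model.
state = torch.load("best.th")

# With map_location, everything is deserialized onto the CPU, so reloading
# the best weights after training does not touch GPU memory.
state_cpu = torch.load("best.th", map_location=torch.device("cpu"))
model.load_state_dict(state_cpu)
```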
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `FBetaMultiLabelMeasure` now works with multiple dimensions
- Support for inferior operating systems when making hardlinks
- Use `,` as a separator for filenames in the `evaluate` command, thus allowing for URLs (eg. `gs://...`) as input files.
- Load model on CPU post training to save GPU memory.

### Removed

@@ -51,7 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added in a default behavior to the `_to_params` method of `Registrable` so that in the case it is not implemented by the child class, it will still produce _a parameter dictionary_.
- Added in `_to_params` implementations to all tokenizers.
- Added support to evaluate multiple datasets and produce corresponding output files in the `evaluate` command.
- Added more documentation to the learning rate schedulers to include a sample config object for how to use it.
6 changes: 4 additions & 2 deletions allennlp/training/gradient_descent_trainer.py
@@ -965,10 +965,12 @@ def _save_model_state(self, path: str) -> None:
         torch.save(self.model.state_dict(), path)
 
     def _load_model_state(self, path: str) -> None:
+        # This function is only called after training. So load model on the CPU.
+        device = torch.device("cpu")
         if self._ddp_wrapped_model is not None:
-            self._ddp_wrapped_model.load_state_dict(torch.load(path))
+            self._ddp_wrapped_model.load_state_dict(torch.load(path, map_location=device))
         else:
-            self._pytorch_model.load_state_dict(torch.load(path))
+            self._pytorch_model.load_state_dict(torch.load(path, map_location=device))
 
     def _finalize_model(self) -> None:
         """If we have a moving average, we have to finalize the model at the end of training."""