Save scaler state dict when checkpointing (huggingface#11663)
sgugger committed May 10, 2021
1 parent ef8d32c commit 05a9306
Showing 1 changed file with 6 additions and 0 deletions.
src/transformers/trainer.py
@@ -1480,12 +1480,16 @@ def _save_checkpoint(self, model, trial, metrics=None):
                     with warnings.catch_warnings(record=True) as caught_warnings:
                         torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                     reissue_pt_warnings(caught_warnings)
+                    if self.use_amp:
+                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
         elif self.is_world_process_zero() and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
             reissue_pt_warnings(caught_warnings)
+            if self.use_amp:
+                torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -1569,6 +1573,8 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
             reissue_pt_warnings(caught_warnings)
+            if self.use_amp and os.path.isfile(os.path.join(checkpoint, "scaler.pt")):
+                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, "scaler.pt")))
 
     def hyperparameter_search(
         self,
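
Why this matters when resuming: torch.cuda.amp.GradScaler keeps a dynamic loss scale that it calibrates during training (halving on overflow and skipping that optimizer step). If the scaler state is not checkpointed, a resumed run restarts calibration from the default init_scale and can silently skip early optimizer steps. Below is a minimal sketch of the same save/load pattern outside the Trainer; the model, optimizer, and ckpt_dir names are illustrative and not part of this commit.

import os

import torch

# Assumes a CUDA device is available, as AMP training does.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

def save_checkpoint(ckpt_dir):
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))
    # The fix in this commit: persist the scaler alongside the optimizer so
    # the calibrated loss scale survives a restart.
    torch.save(scaler.state_dict(), os.path.join(ckpt_dir, "scaler.pt"))

def load_checkpoint(ckpt_dir):
    optimizer.load_state_dict(torch.load(os.path.join(ckpt_dir, "optimizer.pt")))
    scaler_path = os.path.join(ckpt_dir, "scaler.pt")
    # Guard as the commit does: checkpoints written before this change
    # have no scaler.pt.
    if os.path.isfile(scaler_path):
        scaler.load_state_dict(torch.load(scaler_path))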
