Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
make W&B callback resumable (#5312)
Browse files Browse the repository at this point in the history
Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
  • Loading branch information
epwalsh and dirkgr authored Jul 19, 2021
1 parent 9629340 commit ef5400d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the docs for `PytorchTransformerWrapper`
- Fixed recovery of training jobs for models that expect `get_metrics()` not to be called until they have seen at least one batch.
- Made the Transformer Toolkit compatible with transformers that don't start their positional embeddings at 0.
- Weights & Biases training callback ("wandb") now works when resuming training jobs.

### Changed

Expand Down
16 changes: 15 additions & 1 deletion allennlp/training/callbacks/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(

self._watch_model = watch_model
self._files_to_save = files_to_save
self._run_id: Optional[str] = None
self._wandb_kwargs: Dict[str, Any] = dict(
dir=os.path.abspath(serialization_dir),
project=project,
Expand Down Expand Up @@ -141,7 +142,11 @@ def on_start(
import wandb

self.wandb = wandb
self.wandb.init(**self._wandb_kwargs)

if self._run_id is None:
self._run_id = self.wandb.util.generate_id()

self.wandb.init(id=self._run_id, **self._wandb_kwargs)

for fpath in self._files_to_save:
self.wandb.save( # type: ignore
Expand All @@ -155,3 +160,12 @@ def on_start(
def close(self) -> None:
    """
    Run the parent class's `close()` first, then call `finish()` on the
    `wandb` module handle stored on `self.wandb`.
    """
    super().close()
    wandb_module = self.wandb
    wandb_module.finish()  # type: ignore

def state_dict(self) -> Dict[str, Any]:
    """
    Return the state to checkpoint for this callback: a dictionary holding
    the current W&B run ID (possibly `None`) under the key `"run_id"`.
    """
    checkpoint_state: Dict[str, Any] = {}
    checkpoint_state["run_id"] = self._run_id
    return checkpoint_state

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """
    Restore this callback from a checkpointed `state_dict`.

    Sets `resume="auto"` in the keyword arguments kept for `wandb.init()`
    and adopts the run ID stored under `state_dict["run_id"]`, so that
    `on_start` passes that ID to `wandb.init()` instead of generating a
    new one.
    """
    # The resume flag is set before the lookup, so it is recorded even if
    # the "run_id" key turns out to be missing and a KeyError is raised.
    self._wandb_kwargs.update(resume="auto")
    saved_run_id = state_dict["run_id"]
    self._run_id = saved_run_id

0 comments on commit ef5400d

Please sign in to comment.