Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Fix ShouldValidateCallback #5536

Merged
merged 5 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed a spurious error message "'torch.cuda' has no attribute '_check_driver'" that would be appear in the logs
when a `ConfigurationError` for missing GPU was raised.
- Load model on CPU post training to save GPU memory.
- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring after the first epoch regardless of `validation_start` value.
- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring every `validation_interval + 1` epochs, instead of every `validation_interval` epochs.

### Removed

Expand Down
16 changes: 12 additions & 4 deletions allennlp/training/callbacks/should_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def __init__(
self._validation_start = validation_start
self._validation_interval = validation_interval

def on_start(
self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
) -> None:
trainer._should_validate_this_epoch = self._should_validate(epoch=0)

def on_epoch(
self,
trainer: "GradientDescentTrainer",
Expand All @@ -33,9 +38,12 @@ def on_epoch(
is_primary: bool = True,
**kwargs,
) -> None:
trainer._should_validate_this_epoch = self._should_validate(epoch=epoch + 1)

def _should_validate(self, epoch: int) -> bool:
should_validate = True
if self._validation_start is not None and epoch < self._validation_start:
trainer._should_validate_this_epoch = False
should_validate = False
elif self._validation_interval is not None and epoch % self._validation_interval != 0:
trainer._should_validate_this_epoch = False
else:
trainer._should_validate_this_epoch = True
should_validate = False
return should_validate
10 changes: 7 additions & 3 deletions tests/training/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,16 +1361,20 @@ def test_should_validate_callback(self):
)
trainer.train()

# Doesn't satisfy 'validation_start' or 'validation_interval'
# Shouldn't validate on the first epoch as it's before the 'validation_start'
callback.on_start(trainer)
assert not trainer._should_validate_this_epoch

# Satisfies 'validation_interval' but not 'validation_start'
callback.on_epoch(trainer, metrics={}, epoch=1)
assert not trainer._should_validate_this_epoch

# Satisfies 'validation_start' but not 'validation_interval'
# Doesn't satisfy 'validation_start' or 'validation_interval'
callback.on_epoch(trainer, metrics={}, epoch=2)
assert not trainer._should_validate_this_epoch

# Satisfies both 'validation_start' and 'validation_interval'
callback.on_epoch(trainer, metrics={}, epoch=4)
callback.on_epoch(trainer, metrics={}, epoch=5)
assert trainer._should_validate_this_epoch


Expand Down