diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b420144ba1..29c7005a3b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed a spurious error message "'torch.cuda' has no attribute '_check_driver'" that would be appear in the logs when a `ConfigurationError` for missing GPU was raised. - Load model on CPU post training to save GPU memory. +- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring after the first epoch regardless of `validation_start` value. +- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring every `validation_interval + 1` epochs, instead of every `validation_interval` epochs. ### Removed diff --git a/allennlp/training/callbacks/should_validate.py b/allennlp/training/callbacks/should_validate.py index a1008fc70a6..9ed801523d1 100644 --- a/allennlp/training/callbacks/should_validate.py +++ b/allennlp/training/callbacks/should_validate.py @@ -25,6 +25,11 @@ def __init__( self._validation_start = validation_start self._validation_interval = validation_interval + def on_start( + self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs + ) -> None: + trainer._should_validate_this_epoch = self._should_validate(epoch=0) + def on_epoch( self, trainer: "GradientDescentTrainer", @@ -33,9 +38,12 @@ def on_epoch( is_primary: bool = True, **kwargs, ) -> None: + trainer._should_validate_this_epoch = self._should_validate(epoch=epoch + 1) + + def _should_validate(self, epoch: int) -> bool: + should_validate = True if self._validation_start is not None and epoch < self._validation_start: - trainer._should_validate_this_epoch = False + should_validate = False elif self._validation_interval is not None and epoch % self._validation_interval != 0: - trainer._should_validate_this_epoch = False - else: - trainer._should_validate_this_epoch = True + should_validate = False + return should_validate diff --git a/tests/training/trainer_test.py b/tests/training/trainer_test.py index 576a61d2183..bad84580000 100644 --- a/tests/training/trainer_test.py +++ b/tests/training/trainer_test.py @@ -1361,16 +1361,20 @@ def test_should_validate_callback(self): ) trainer.train() - # Doesn't satisfy 'validation_start' or 'validation_interval' + # Shouldn't validate on the first epoch as it's before the 'validation_start' + callback.on_start(trainer) + assert not trainer._should_validate_this_epoch + + # Satisfies 'validation_interval' but not 'validation_start' callback.on_epoch(trainer, metrics={}, epoch=1) assert not trainer._should_validate_this_epoch - # Satisfies 'validation_start' but not 'validation_interval' + # Doesn't satisfy 'validation_start' or 'validation_interval' callback.on_epoch(trainer, metrics={}, epoch=2) assert not trainer._should_validate_this_epoch # Satisfies both 'validation_start' and 'validation_interval' - callback.on_epoch(trainer, metrics={}, epoch=4) + callback.on_epoch(trainer, metrics={}, epoch=5) assert trainer._should_validate_this_epoch