diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 8716cb93b..a054f24a7 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -383,9 +383,9 @@ def legacy_validate_config(cfg): "push_to_hub_model_id is deprecated. Please use hub_model_id instead." ) - if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch): + if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]: LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." + "hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty." ) if cfg.gptq and cfg.revision_of_model: @@ -448,10 +448,14 @@ def legacy_validate_config(cfg): raise ValueError( "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." ) - if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps": + if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps": raise ValueError( "save_strategy must be empty or set to `steps` when used with saves_per_epoch." ) + if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": + raise ValueError( + "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." + ) if cfg.evals_per_epoch and cfg.eval_steps: raise ValueError( "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." @@ -464,11 +468,6 @@ def legacy_validate_config(cfg): raise ValueError( "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." ) - if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": - raise ValueError( - "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." - ) - if ( cfg.evaluation_strategy and cfg.eval_steps diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index e27a8ddd5..72e82e823 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -780,11 +780,11 @@ def check_saves(cls, data): @model_validator(mode="before") @classmethod def check_push_save(cls, data): - if data.get("hub_model_id") and not ( - data.get("save_steps") or data.get("saves_per_epoch") + if data.get("hub_model_id") and ( + data.get("save_strategy") not in ["steps", "epoch", None] ): LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." + "hub_model_id is set without any models being saved. To save a model, set save_strategy." ) return data diff --git a/tests/test_validation.py b/tests/test_validation.py index 4865712c4..27824f288 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1067,17 +1067,51 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg): ): validate_config(cfg) - def test_hub_model_id_save_value_warns(self, minimal_cfg): - cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg + def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg with self._caplog.at_level(logging.WARNING): validate_config(cfg) - assert ( - "set without any models being saved" in self._caplog.records[0].message - ) + assert len(self._caplog.records) == 1 + + def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 1 + + def test_hub_model_id_save_value_steps(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "steps"}) + | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_epochs(self, minimal_cfg): + cfg = ( + DictDefault({"hub_model_id": "test", "save_strategy": "epoch"}) + | minimal_cfg + ) - def test_hub_model_id_save_value(self, minimal_cfg): - cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_none(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + + def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): + cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg with self._caplog.at_level(logging.WARNING): validate_config(cfg)