From 383f88d7a71286e5c74f50e3caf893bd9d071fcf Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 28 Sep 2023 10:14:41 +0900 Subject: [PATCH] Fix(cfg): Add validation for save_strategy and eval_strategy (#633) * Fix(cfg): Check save_strategy cfg conflict with save_steps * Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps * chore: add extra check for steps only --- src/axolotl/utils/config.py | 18 ++++ src/axolotl/utils/trainer.py | 15 +--- tests/test_validation.py | 168 +++++++++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 1dfdab2605..ac067b5055 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -296,6 +296,24 @@ def validate_config(cfg): cfg.datasets[idx].type = cfg.datasets[idx].type.replace( "sharegpt_simple", "sharegpt" ) + if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": + raise ValueError( + "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." + ) + + if ( + cfg.evaluation_strategy + and cfg.eval_steps + and cfg.evaluation_strategy != "steps" + ): + raise ValueError( + "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps." + ) + + if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy): + raise ValueError( + "eval_steps and evaluation_strategy are not supported with val_set_size == 0" + ) # TODO # MPT 7b diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index aee2a1b99e..3c75e4ec53 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -604,26 +604,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ "sample_packing_efficiency" ] = cfg.sample_packing_eff_est - if cfg.eval_steps and cfg.evaluation_strategy: - # assume if the user set both, they know what they're doing - training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy + if cfg.eval_steps: + training_arguments_kwargs["evaluation_strategy"] = "steps" training_arguments_kwargs["eval_steps"] = cfg.eval_steps + elif cfg.evaluation_strategy: + training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy elif cfg.val_set_size == 0: # no eval set, so don't eval training_arguments_kwargs["evaluation_strategy"] = "no" - elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]: - # if explicitly set for epoch, just set, and eval steps don't matter - training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy - elif cfg.eval_steps: - # steps isn't used w/ epochs - training_arguments_kwargs["evaluation_strategy"] = "steps" - training_arguments_kwargs["eval_steps"] = cfg.eval_steps else: # we have an eval set, but no steps defined, default to use epoch training_arguments_kwargs["evaluation_strategy"] = "epoch" if cfg.save_steps: - # save_steps implies save_strategy of steps training_arguments_kwargs["save_strategy"] = "steps" training_arguments_kwargs["save_steps"] = cfg.save_steps elif cfg.save_strategy: diff --git a/tests/test_validation.py b/tests/test_validation.py index b9a57c2e9a..35d90a2cb4 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -397,3 +397,171 @@ def test_sharegpt_deprecation(self): for record in self._caplog.records ) assert cfg.datasets[0].type == "sharegpt:load_role" + + def test_no_conflict_save_strategy(self): + cfg = DictDefault( + { + "save_strategy": "epoch", + "save_steps": 10, + } + ) + + with pytest.raises( + ValueError, match=r".*save_strategy and save_steps mismatch.*" + ): + validate_config(cfg) + + cfg = DictDefault( + { + "save_strategy": "no", + "save_steps": 10, + } + ) + + with pytest.raises( + ValueError, match=r".*save_strategy and save_steps mismatch.*" + ): + validate_config(cfg) + + cfg = DictDefault( + { + "save_strategy": "steps", + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "save_strategy": "steps", + "save_steps": 10, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "save_steps": 10, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "save_strategy": "no", + } + ) + + validate_config(cfg) + + def test_no_conflict_eval_strategy(self): + cfg = DictDefault( + { + "evaluation_strategy": "epoch", + "eval_steps": 10, + } + ) + + with pytest.raises( + ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" + ): + validate_config(cfg) + + cfg = DictDefault( + { + "evaluation_strategy": "no", + "eval_steps": 10, + } + ) + + with pytest.raises( + ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" + ): + validate_config(cfg) + + cfg = DictDefault( + { + "evaluation_strategy": "steps", + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "evaluation_strategy": "steps", + "eval_steps": 10, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "eval_steps": 10, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "evaluation_strategy": "no", + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "evaluation_strategy": "epoch", + "val_set_size": 0, + } + ) + + with pytest.raises( + ValueError, + match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "eval_steps": 10, + "val_set_size": 0, + } + ) + + with pytest.raises( + ValueError, + match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "val_set_size": 0, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "eval_steps": 10, + "val_set_size": 0.01, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "evaluation_strategy": "epoch", + "val_set_size": 0.01, + } + ) + + validate_config(cfg)