diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index bf8c4145bd..236caeea46 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -66,6 +66,29 @@ def on_save(
         return control
 
 
+class EvalFirstStepCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods disable=unused-argument
+    """
+    Callback to trigger evals on the first step
+    """
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if (
+            args.evaluation_strategy == IntervalStrategy.STEPS
+            and args.eval_steps < 1.0
+            and state.global_step == 1
+        ):
+            control.should_evaluate = True
+        return control
+
+
 class SaveBetterTransformerModelCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2067a90069..944ac5f511 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -28,6 +28,7 @@
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
+    EvalFirstStepCallback,
     GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
@@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     callbacks = []
     callbacks.append(GPUStatsCallback(cfg))
+    callbacks.append(EvalFirstStepCallback)
     if cfg.relora_steps:
         callbacks.append(ReLoRACallback(cfg))
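
For reviewers: below is a minimal sketch (not part of the diff) of when `EvalFirstStepCallback` fires. It drives the callback by hand with real `transformers` objects instead of running a full training loop; the `/tmp/demo` output dir and the `0.05` eval ratio are arbitrary illustration values.

```python
# Minimal sketch: exercise the new callback directly, no Trainer run needed.
from transformers import TrainerControl, TrainerState, TrainingArguments

from axolotl.utils.callbacks import EvalFirstStepCallback

args = TrainingArguments(
    output_dir="/tmp/demo",  # hypothetical path, illustration only
    evaluation_strategy="steps",
    eval_steps=0.05,  # a float < 1 is treated as a fraction of total steps
)

# Simulate the state right after the first optimizer step.
state = TrainerState(global_step=1)
control = TrainerControl()

EvalFirstStepCallback().on_step_end(args, state, control)
assert control.should_evaluate  # eval is forced on the very first step
```

One nit on the trainer.py hunk: `callbacks.append(EvalFirstStepCallback)` appends the class rather than an instance. `Trainer` accepts both (its `add_callback` instantiates a class automatically), so this works, but `EvalFirstStepCallback()` would match the `GPUStatsCallback(cfg)` style of the neighboring lines.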