From 2844eb22b63fecfec8ba98ee3f6fc5ac2940bb5a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 21 Sep 2023 21:51:09 -0400 Subject: [PATCH] run eval on the first step to get a baseline (#617) * run eval on the first step to get a baseline * wandb kleeps getting moved around by pre-commit ... --- src/axolotl/utils/callbacks.py | 23 +++++++++++++++++++++++ src/axolotl/utils/trainer.py | 2 ++ 2 files changed, 25 insertions(+) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index bf8c4145bd..236caeea46 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -66,6 +66,29 @@ def on_save( return control +class EvalFirstStepCallback( + TrainerCallback +): # pylint: disable=too-few-public-methods disable=unused-argument + """ + Callback to trigger evals on the first step + """ + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if ( + args.evaluation_strategy == IntervalStrategy.STEPS + and args.eval_steps < 1.0 + and state.global_step == 1 + ): + control.should_evaluate = True + return control + + class SaveBetterTransformerModelCallback( TrainerCallback ): # pylint: disable=too-few-public-methods diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2067a90069..944ac5f511 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -28,6 +28,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils.callbacks import ( + EvalFirstStepCallback, GPUStatsCallback, SaveBetterTransformerModelCallback, SavePeftModelCallback, @@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ callbacks = [] callbacks.append(GPUStatsCallback(cfg)) + callbacks.append(EvalFirstStepCallback) if cfg.relora_steps: callbacks.append(ReLoRACallback(cfg))