Skip to content

Commit

Permalink
move early stopping callback after the benchmark evals
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 8, 2023
1 parent 8b87631 commit 687d343
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))

# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)

if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
Expand Down Expand Up @@ -725,4 +718,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)

return trainer

0 comments on commit 687d343

Please sign in to comment.