diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index e66f165a53..2965ac1e29 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -778,6 +778,17 @@ def on_train_begin( class SaveModelOnTrainEndCallback(TrainerCallback): """Callback to save model on train end""" + def on_step_end( # pylint: disable=unused-argument + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Save + if state.global_step >= state.max_steps: + control.should_save = True + def on_train_end( # pylint: disable=unused-argument self, args, state, control, **kwargs ):