From 29cf15a28cfa8eb310af55b837e9cc56c9b2a571 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 4 May 2024 23:19:18 -0400 Subject: [PATCH] improve save callbacks (#1592) --- src/axolotl/core/trainer_builder.py | 22 +++++++++++++--------- src/axolotl/utils/callbacks/__init__.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index bf18a287a8..742a88633a 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -43,6 +43,7 @@ LossWatchDogCallback, SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, + SaveModelOnTrainEndCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, log_prediction_callback_factory, @@ -888,6 +889,14 @@ def get_callbacks(self) -> List[TrainerCallback]: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_mlflow and is_mlflow_available(): + from axolotl.utils.callbacks.mlflow_ import ( + SaveAxolotlConfigtoMlflowCallback, + ) + + callbacks.append( + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + ) return callbacks @@ -933,18 +942,11 @@ def get_callbacks(self): ): callbacks.append(SaveBetterTransformerModelCallback()) - if self.cfg.use_mlflow and is_mlflow_available(): - from axolotl.utils.callbacks.mlflow_ import ( - SaveAxolotlConfigtoMlflowCallback, - ) - - callbacks.append( - SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) - ) - if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) + callbacks.append(SaveModelOnTrainEndCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -1427,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() + callbacks.append(SaveModelOnTrainEndCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index fbc1dcfad8..e66f165a53 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -773,3 +773,13 @@ def on_train_begin( except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control + + +class SaveModelOnTrainEndCallback(TrainerCallback): + """Callback to save model on train end""" + + def on_train_end( # pylint: disable=unused-argument + self, args, state, control, **kwargs + ): + control.should_save = True + return control