Skip to content

Commit

Permalink
improve save callbacks (#1592)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored May 5, 2024
1 parent dde02fc commit 29cf15a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelOnTrainEndCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
Expand Down Expand Up @@ -888,6 +889,14 @@ def get_callbacks(self) -> List[TrainerCallback]:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)

callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)

return callbacks

Expand Down Expand Up @@ -933,18 +942,11 @@ def get_callbacks(self):
):
callbacks.append(SaveBetterTransformerModelCallback())

if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)

callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)

if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))

callbacks.append(SaveModelOnTrainEndCallback())

return callbacks

def get_post_trainer_create_callbacks(self, trainer):
Expand Down Expand Up @@ -1427,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):

def get_callbacks(self):
    """Return the parent builder's callbacks with a save-on-train-end hook added.

    The list obtained from ``super().get_callbacks()`` is extended in place
    with a ``SaveModelOnTrainEndCallback`` instance and then returned.
    """
    inherited = super().get_callbacks()
    inherited.append(SaveModelOnTrainEndCallback())
    return inherited

def get_post_trainer_create_callbacks(self, trainer):
Expand Down
10 changes: 10 additions & 0 deletions src/axolotl/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,13 @@ def on_train_begin(
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control


class SaveModelOnTrainEndCallback(TrainerCallback):
    """Trainer callback that flags the control object for a save when training ends."""

    def on_train_end(  # pylint: disable=unused-argument
        self,
        args,
        state,
        control,
        **kwargs,
    ):
        """Set ``control.should_save`` to ``True`` and return ``control``.

        ``args``, ``state`` and any extra keyword arguments are accepted to
        match the callback signature but are not read here.

        NOTE(review): this only raises the flag; whether the trainer still
        inspects ``should_save`` after the train-end event is not visible
        from this file -- confirm against the Trainer implementation.
        """
        control.should_save = True
        return control

0 comments on commit 29cf15a

Please sign in to comment.