diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 18dc353a23..9ad6c464f0 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -28,6 +28,7 @@ EvalFirstStepCallback, GPUStatsCallback, LossWatchDogCallback, + SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, bench_eval_callback_factory, @@ -542,6 +543,10 @@ def get_callbacks(self): callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_mlflow: + callbacks.append( + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + ) if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 122cd92ede..9de266be10 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Dict, List import evaluate +import mlflow import numpy as np import pandas as pd import torch @@ -575,3 +576,31 @@ def on_train_begin( except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control + + +class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): + """Callback to save axolotl config to mlflow""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: AxolotlTrainingArguments, # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + mlflow.log_artifact(temp_file.name, artifact_path="") + LOG.info( + "The Axolotl config has been saved to the MLflow artifacts." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to MLflow: {err}") + return control