From 928d070093b95079ae57c8ec63e2b88bb7c88e64 Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Tue, 23 Jan 2024 00:44:39 +0100 Subject: [PATCH] Add mlflow callback for pushing config to mlflow artifacts (#1125) * Update callbacks.py adding callback for mlflow * Update trainer_builder.py * clean up --- src/axolotl/core/trainer_builder.py | 5 +++++ src/axolotl/utils/callbacks.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index bb027acf22..c3b01e6c6e 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -28,6 +28,7 @@ EvalFirstStepCallback, GPUStatsCallback, LossWatchDogCallback, + SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, bench_eval_callback_factory, @@ -543,6 +544,10 @@ def get_callbacks(self): callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_mlflow: + callbacks.append( + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + ) if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 122cd92ede..9de266be10 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Dict, List import evaluate +import mlflow import numpy as np import pandas as pd import torch @@ -575,3 +576,31 @@ def on_train_begin( except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control + + +class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): + """Callback to save axolotl config to mlflow""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: AxolotlTrainingArguments, # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + mlflow.log_artifact(temp_file.name, artifact_path="") + LOG.info( + "The Axolotl config has been saved to the MLflow artifacts." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to MLflow: {err}") + return control