From 460cff1d5bed7fa26de937dff33cdb9064961b7e Mon Sep 17 00:00:00 2001 From: Johan Hansson <39947546+JohanWork@users.noreply.github.com> Date: Mon, 15 Jan 2024 16:11:33 +0100 Subject: [PATCH 1/3] Update callbacks.py adding callback for mlflow --- src/axolotl/utils/callbacks.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 122cd92ede..eedf9738a0 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -14,6 +14,7 @@ import torch import torch.distributed as dist import wandb +import mlflow from datasets import load_dataset from optimum.bettertransformer import BetterTransformer from tqdm import tqdm @@ -575,3 +576,31 @@ def on_train_begin( except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control + + +class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): + """Callback to save axolotl config to mlflow""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: AxolotlTrainingArguments, # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + mlflow.log_artifact(temp_file.name, artifact_path='') + LOG.info( + "The Axolotl config has been saved to the MLflow artifacts." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to MLflow: {err}") + return control From a1c3128b1417ba2053f5008571e0cb3627f0d186 Mon Sep 17 00:00:00 2001 From: Johan Hansson <39947546+JohanWork@users.noreply.github.com> Date: Mon, 15 Jan 2024 16:13:15 +0100 Subject: [PATCH 2/3] Update trainer_builder.py --- src/axolotl/core/trainer_builder.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 18dc353a23..32597beaa3 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -32,6 +32,7 @@ SaveBetterTransformerModelCallback, bench_eval_callback_factory, log_prediction_callback_factory, + SaveAxolotlConfigtoMlflowCallback, ) from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -542,6 +543,11 @@ def get_callbacks(self): callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_mlflow: + callbacks.append( + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + ) + if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) From d3716ced9a398213a43698ad42adde0d259ca79a Mon Sep 17 00:00:00 2001 From: Johan Hansson <39947546+JohanWork@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:22:11 +0000 Subject: [PATCH 3/3] clean up --- src/axolotl/core/trainer_builder.py | 3 +-- src/axolotl/utils/callbacks.py | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 32597beaa3..9ad6c464f0 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -28,11 +28,11 @@ EvalFirstStepCallback, GPUStatsCallback, LossWatchDogCallback, + SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, bench_eval_callback_factory, log_prediction_callback_factory, - SaveAxolotlConfigtoMlflowCallback, ) from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -547,7 +547,6 @@ def get_callbacks(self): callbacks.append( SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) ) - if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index eedf9738a0..9de266be10 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -9,12 +9,12 @@ from typing import TYPE_CHECKING, Dict, List import evaluate +import mlflow import numpy as np import pandas as pd import torch import torch.distributed as dist import wandb -import mlflow from datasets import load_dataset from optimum.bettertransformer import BetterTransformer from tqdm import tqdm @@ -590,17 +590,17 @@ def on_train_begin( state: TrainerState, # pylint: disable=unused-argument control: TrainerControl, **kwargs, # pylint: disable=unused-argument - ): + ): if is_main_process(): try: - with NamedTemporaryFile( + with NamedTemporaryFile( mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" ) as temp_file: copyfile(self.axolotl_config_path, temp_file.name) - mlflow.log_artifact(temp_file.name, artifact_path='') + mlflow.log_artifact(temp_file.name, artifact_path="") LOG.info( - "The Axolotl config has been saved to the MLflow artifacts." - ) + "The Axolotl config has been saved to the MLflow artifacts." + ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to MLflow: {err}") return control