Skip to content

Commit

Permalink
Add mlflow callback for pushing config to mlflow artifacts (#1125)
Browse files Browse the repository at this point in the history
* Update callbacks.py

adding callback for mlflow

* Update trainer_builder.py

* clean up
  • Loading branch information
JohanWork authored Jan 22, 2024
1 parent 8b17d30 commit 928d070
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
Expand Down Expand Up @@ -543,6 +544,10 @@ def get_callbacks(self):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow:
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)

if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
Expand Down
29 changes: 29 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING, Dict, List

import evaluate
import mlflow
import numpy as np
import pandas as pd
import torch
Expand Down Expand Up @@ -575,3 +576,31 @@ def on_train_begin(
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control


class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
"""Callback to save axolotl config to mlflow"""

def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path

def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

0 comments on commit 928d070

Please sign in to comment.