From 2b87bde0edcef0da685f16d94656ec01b2987317 Mon Sep 17 00:00:00 2001 From: Dave Farago Date: Tue, 9 Apr 2024 18:39:23 +0200 Subject: [PATCH 1/4] WIP: Support table logging for mlflow, too Create a `LogPredictionCallback` for both "wandb" and "mlflow" if specified. In `log_prediction_callback_factory`, create a generic table and make it specific only if the newly added `logger` argument is set to "wandb" resp. "mlflow". See https://github.com/OpenAccess-AI-Collective/axolotl/issues/1505 --- src/axolotl/core/trainer_builder.py | 7 ++++- src/axolotl/utils/callbacks/__init__.py | 39 +++++++++++++------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 8d08d60b36..e3a1c1cf6c 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -943,7 +943,12 @@ def get_post_trainer_create_callbacks(self, trainer): callbacks = [] if self.cfg.use_wandb and self.cfg.eval_table_size > 0: LogPredictionCallback = log_prediction_callback_factory( - trainer, self.tokenizer + trainer, self.tokenizer, "wandb" + ) + callbacks.append(LogPredictionCallback(self.cfg)) + if self.cfg.use_mlflow and is_mlflow_available() and self.cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory( + trainer, self.tokenizer, "mlflow" ) callbacks.append(LogPredictionCallback(self.cfg)) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 6a489f6c0e..7dd9402999 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Dict, List import evaluate +import mlflow import numpy as np import pandas as pd import torch @@ -28,6 +29,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.distributed import ( barrier, broadcast_dict, @@ -540,7 +542,7 @@ def predict_with_generate(): return CausalLMBenchEvalCallback -def log_prediction_callback_factory(trainer: Trainer, tokenizer): +def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): class LogPredictionCallback(TrainerCallback): """Callback to log prediction values during each evaluation""" @@ -597,15 +599,13 @@ def find_ranges(lst): return ranges def log_table_from_dataloader(name: str, table_dataloader): - table = wandb.Table( # type: ignore[attr-defined] - columns=[ - "id", - "Prompt", - "Correct Completion", - "Predicted Completion (model.generate)", - "Predicted Completion (trainer.prediction_step)", - ] - ) + table_data = { + "id": [], + "Prompt": [], + "Correct Completion": [], + "Predicted Completion (model.generate)": [], + "Predicted Completion (trainer.prediction_step)": [], + } row_index = 0 for batch in tqdm(table_dataloader): @@ -709,16 +709,17 @@ def log_table_from_dataloader(name: str, table_dataloader): ) in zip( prompt_texts, completion_texts, predicted_texts, pred_step_texts ): - table.add_data( - row_index, - prompt_text, - completion_text, - prediction_text, - pred_step_text, - ) + table_data["id"].append(row_index) + table_data["Prompt"].append(prompt_text) + table_data["Correct Completion"].append(completion_text) + table_data["Predicted Completion (model.generate)"].append(prediction_text) + table_data["Predicted Completion (trainer.prediction_step)"].append(pred_step_text) row_index += 1 - - wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined] + if logger == "wandb": + wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] + elif logger == "mlflow": + tracking_uri = AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri + mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json", tracking_uri = tracking_uri) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader) From 934c4f9a9ff4552c05dcb325af052c6394e05ca0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Apr 2024 15:55:48 -0400 Subject: [PATCH 2/4] chore: lint --- src/axolotl/core/trainer_builder.py | 6 +++++- src/axolotl/utils/callbacks/__init__.py | 24 +++++++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e3a1c1cf6c..8bdcfba268 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -946,7 +946,11 @@ def get_post_trainer_create_callbacks(self, trainer): trainer, self.tokenizer, "wandb" ) callbacks.append(LogPredictionCallback(self.cfg)) - if self.cfg.use_mlflow and is_mlflow_available() and self.cfg.eval_table_size > 0: + if ( + self.cfg.use_mlflow + and is_mlflow_available() + and self.cfg.eval_table_size > 0 + ): LogPredictionCallback = log_prediction_callback_factory( trainer, self.tokenizer, "mlflow" ) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 7dd9402999..ba2f439b27 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -6,7 +6,7 @@ import os from shutil import copyfile from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List import evaluate import mlflow @@ -599,7 +599,7 @@ def find_ranges(lst): return ranges def log_table_from_dataloader(name: str, table_dataloader): - table_data = { + table_data: Dict[str, List[Any]] = { "id": [], "Prompt": [], "Correct Completion": [], @@ -712,14 +712,24 @@ def log_table_from_dataloader(name: str, table_dataloader): table_data["id"].append(row_index) table_data["Prompt"].append(prompt_text) table_data["Correct Completion"].append(completion_text) - table_data["Predicted Completion (model.generate)"].append(prediction_text) - table_data["Predicted Completion (trainer.prediction_step)"].append(pred_step_text) + table_data["Predicted Completion (model.generate)"].append( + prediction_text + ) + table_data[ + "Predicted Completion (trainer.prediction_step)" + ].append(pred_step_text) row_index += 1 if logger == "wandb": - wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] + wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] elif logger == "mlflow": - tracking_uri = AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri - mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json", tracking_uri = tracking_uri) + tracking_uri = AxolotlInputConfig( + **self.cfg.to_dict() + ).mlflow_tracking_uri + mlflow.log_table( + data=table_data, + artifact_file="PredictionsVsGroundTruth.json", + tracking_uri=tracking_uri, + ) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader) From 622b28286190a33c61bf60ea35c881fb09c4d506 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Apr 2024 16:09:17 -0400 Subject: [PATCH 3/4] add additional clause for mlflow as it's optional --- src/axolotl/utils/callbacks/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index ba2f439b27..890883512e 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Dict, List import evaluate -import mlflow import numpy as np import pandas as pd import torch @@ -28,6 +27,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +from axolotl.core.trainer_builder import is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.distributed import ( @@ -721,7 +721,9 @@ def log_table_from_dataloader(name: str, table_dataloader): row_index += 1 if logger == "wandb": wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] - elif logger == "mlflow": + elif logger == "mlflow" and is_mlflow_available(): + import mlflow + tracking_uri = AxolotlInputConfig( **self.cfg.to_dict() ).mlflow_tracking_uri From 3b13b54a93a95abc3ad1c7322319223450323893 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Apr 2024 16:21:00 -0400 Subject: [PATCH 4/4] Fix circular imports --- src/axolotl/core/trainer_builder.py | 5 +---- src/axolotl/utils/__init__.py | 8 ++++++++ src/axolotl/utils/callbacks/__init__.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 8bdcfba268..35318b836d 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -36,6 +36,7 @@ from axolotl.loraplus import create_loraplus_optimizer from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler +from axolotl.utils import is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, @@ -71,10 +72,6 @@ LOG = logging.getLogger("axolotl.core.trainer_builder") -def is_mlflow_available(): - return importlib.util.find_spec("mlflow") is not None - - def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): if isinstance(tag_names, str): tag_names = [tag_names] diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index e69de29bb2..99dec79f1b 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -0,0 +1,8 @@ +""" +Basic utils for Axolotl +""" +import importlib + + +def is_mlflow_available(): + return importlib.util.find_spec("mlflow") is not None diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 890883512e..d907e3f6a3 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -27,7 +27,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy -from axolotl.core.trainer_builder import is_mlflow_available +from axolotl.utils import is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.distributed import (