Log the eval prediction table to MLflow as well as Weights & Biases.

Review notes for this revision of the patch:
  * axolotl/utils/__init__.py imports `importlib.util` explicitly; a bare
    `import importlib` does not guarantee the `util` submodule is loaded.
  * `mlflow.log_table` takes (data, artifact_file, run_id); it has no
    `tracking_uri` parameter, so the configured URI is applied through
    `mlflow.set_tracking_uri` before the table is logged.
  * Indentation and blank context lines were reconstructed from a
    whitespace-collapsed copy of the patch; the `index` blob ids of the two
    edited files predate the fixes above.

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 8d08d60b36..35318b836d 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -36,6 +36,7 @@
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
@@ -71,10 +72,6 @@
 LOG = logging.getLogger("axolotl.core.trainer_builder")
 
 
-def is_mlflow_available():
-    return importlib.util.find_spec("mlflow") is not None
-
-
 def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     if isinstance(tag_names, str):
         tag_names = [tag_names]
@@ -943,7 +940,16 @@ def get_post_trainer_create_callbacks(self, trainer):
         callbacks = []
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer
+                trainer, self.tokenizer, "wandb"
+            )
+            callbacks.append(LogPredictionCallback(self.cfg))
+        if (
+            self.cfg.use_mlflow
+            and is_mlflow_available()
+            and self.cfg.eval_table_size > 0
+        ):
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer, self.tokenizer, "mlflow"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
 
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index e69de29bb2..99dec79f1b 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -0,0 +1,9 @@
+"""
+Basic utils for Axolotl
+"""
+import importlib.util
+
+
+def is_mlflow_available():
+    """Return True when the optional ``mlflow`` package can be imported."""
+    return importlib.util.find_spec("mlflow") is not None
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 6a489f6c0e..d907e3f6a3 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -6,7 +6,7 @@
 import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
 
 import evaluate
 import numpy as np
@@ -27,7 +27,9 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -540,7 +542,7 @@ def predict_with_generate():
     return CausalLMBenchEvalCallback
 
 
-def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
 
@@ -597,15 +599,13 @@ def find_ranges(lst):
                 return ranges
 
             def log_table_from_dataloader(name: str, table_dataloader):
-                table = wandb.Table(  # type: ignore[attr-defined]
-                    columns=[
-                        "id",
-                        "Prompt",
-                        "Correct Completion",
-                        "Predicted Completion (model.generate)",
-                        "Predicted Completion (trainer.prediction_step)",
-                    ]
-                )
+                table_data: Dict[str, List[Any]] = {
+                    "id": [],
+                    "Prompt": [],
+                    "Correct Completion": [],
+                    "Predicted Completion (model.generate)": [],
+                    "Predicted Completion (trainer.prediction_step)": [],
+                }
                 row_index = 0
 
                 for batch in tqdm(table_dataloader):
@@ -709,16 +709,32 @@ def log_table_from_dataloader(name: str, table_dataloader):
                     ) in zip(
                         prompt_texts, completion_texts, predicted_texts, pred_step_texts
                     ):
-                        table.add_data(
-                            row_index,
-                            prompt_text,
-                            completion_text,
-                            prediction_text,
-                            pred_step_text,
+                        table_data["id"].append(row_index)
+                        table_data["Prompt"].append(prompt_text)
+                        table_data["Correct Completion"].append(completion_text)
+                        table_data["Predicted Completion (model.generate)"].append(
+                            prediction_text
                         )
+                        table_data[
+                            "Predicted Completion (trainer.prediction_step)"
+                        ].append(pred_step_text)
                         row_index += 1
-
-                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]
+                if logger == "wandb":
+                    wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
+                elif logger == "mlflow" and is_mlflow_available():
+                    import mlflow
+
+                    tracking_uri = AxolotlInputConfig(
+                        **self.cfg.to_dict()
+                    ).mlflow_tracking_uri
+                    # mlflow.log_table() has no tracking_uri parameter, so the
+                    # configured URI is applied globally before logging.
+                    if tracking_uri:
+                        mlflow.set_tracking_uri(tracking_uri)
+                    mlflow.log_table(
+                        data=table_data,
+                        artifact_file="PredictionsVsGroundTruth.json",
+                    )
 
             if is_main_process():
                 log_table_from_dataloader("Eval", eval_dataloader)