diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 8d08d60b36..e3a1c1cf6c 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -943,7 +943,12 @@ def get_post_trainer_create_callbacks(self, trainer):
         callbacks = []
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer
+                trainer, self.tokenizer, "wandb"
+            )
+            callbacks.append(LogPredictionCallback(self.cfg))
+        if self.cfg.use_mlflow and is_mlflow_available() and self.cfg.eval_table_size > 0:
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer, self.tokenizer, "mlflow"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
 
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 6a489f6c0e..7dd9402999 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -9,6 +9,7 @@
 from typing import TYPE_CHECKING, Dict, List
 
 import evaluate
+import mlflow
 import numpy as np
 import pandas as pd
 import torch
@@ -28,6 +29,7 @@
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -540,7 +542,7 @@ def predict_with_generate():
     return CausalLMBenchEvalCallback
 
 
-def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
 
@@ -597,15 +599,13 @@ def find_ranges(lst):
                 return ranges
 
             def log_table_from_dataloader(name: str, table_dataloader):
-                table = wandb.Table(  # type: ignore[attr-defined]
-                    columns=[
-                        "id",
-                        "Prompt",
-                        "Correct Completion",
-                        "Predicted Completion (model.generate)",
-                        "Predicted Completion (trainer.prediction_step)",
-                    ]
-                )
+                table_data = {
+                    "id": [],
+                    "Prompt": [],
+                    "Correct Completion": [],
+                    "Predicted Completion (model.generate)": [],
+                    "Predicted Completion (trainer.prediction_step)": [],
+                }
 
                 row_index = 0
                 for batch in tqdm(table_dataloader):
@@ -709,16 +709,17 @@ def log_table_from_dataloader(name: str, table_dataloader):
                     ) in zip(
                         prompt_texts, completion_texts, predicted_texts, pred_step_texts
                     ):
-                        table.add_data(
-                            row_index,
-                            prompt_text,
-                            completion_text,
-                            prediction_text,
-                            pred_step_text,
-                        )
+                        table_data["id"].append(row_index)
+                        table_data["Prompt"].append(prompt_text)
+                        table_data["Correct Completion"].append(completion_text)
+                        table_data["Predicted Completion (model.generate)"].append(prediction_text)
+                        table_data["Predicted Completion (trainer.prediction_step)"].append(pred_step_text)
                         row_index += 1
-
-                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]
+                if logger == "wandb":
+                    wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
+                elif logger == "mlflow":
+                    tracking_uri = AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri
+                    mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json", tracking_uri=tracking_uri)
 
             if is_main_process():
                 log_table_from_dataloader("Eval", eval_dataloader)
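For reviewers who want to exercise the new path: the mlflow branch in `get_post_trainer_create_callbacks` fires only when `cfg.use_mlflow` is set, `mlflow` is importable (`is_mlflow_available()`), and `cfg.eval_table_size > 0`. A minimal, illustrative config fragment (key names are inferred from the `cfg` attributes referenced above; the URI is a placeholder):

```yaml
# Illustrative only: enables the mlflow eval-table callback added in this diff.
use_mlflow: true
mlflow_tracking_uri: http://localhost:5000  # placeholder URI
eval_table_size: 5  # must be > 0 for the callback to be registered
```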
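For those less familiar with the MLflow side, here is a minimal standalone sketch of the `mlflow.log_table` call used above (available in mlflow >= 2.3). It logs the same column-oriented dict shape that `log_table_from_dataloader` now builds; the run context and sample rows are invented for illustration, and the `tracking_uri` keyword from the diff is omitted since it comes from the axolotl config:

```python
# Standalone sketch, not axolotl code: log a column-oriented dict as a
# table artifact. Sample rows are invented for illustration.
import mlflow

table_data = {
    "id": [0, 1],
    "Prompt": ["2+2=", "The capital of France is"],
    "Correct Completion": ["4", "Paris"],
    "Predicted Completion (model.generate)": ["4", "Paris"],
    "Predicted Completion (trainer.prediction_step)": ["4", "Paris"],
}

with mlflow.start_run():
    # Stored as a JSON artifact on the active run and rendered as a
    # table in the MLflow UI, mirroring the wandb eval table.
    mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json")
```

One small design note: `AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri` re-validates the whole config through the pydantic model just to read one field; if `cfg` already exposes `mlflow_tracking_uri` directly, reading it off `cfg` would avoid that round-trip.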