From 934c4f9a9ff4552c05dcb325af052c6394e05ca0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Apr 2024 15:55:48 -0400 Subject: [PATCH] chore: lint --- src/axolotl/core/trainer_builder.py | 6 +++++- src/axolotl/utils/callbacks/__init__.py | 24 +++++++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e3a1c1cf6c..8bdcfba268 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -946,7 +946,11 @@ def get_post_trainer_create_callbacks(self, trainer): trainer, self.tokenizer, "wandb" ) callbacks.append(LogPredictionCallback(self.cfg)) - if self.cfg.use_mlflow and is_mlflow_available() and self.cfg.eval_table_size > 0: + if ( + self.cfg.use_mlflow + and is_mlflow_available() + and self.cfg.eval_table_size > 0 + ): LogPredictionCallback = log_prediction_callback_factory( trainer, self.tokenizer, "mlflow" ) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 7dd9402999..ba2f439b27 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -6,7 +6,7 @@ import os from shutil import copyfile from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List import evaluate import mlflow @@ -599,7 +599,7 @@ def find_ranges(lst): return ranges def log_table_from_dataloader(name: str, table_dataloader): - table_data = { + table_data: Dict[str, List[Any]] = { "id": [], "Prompt": [], "Correct Completion": [], @@ -712,14 +712,24 @@ def log_table_from_dataloader(name: str, table_dataloader): table_data["id"].append(row_index) table_data["Prompt"].append(prompt_text) table_data["Correct Completion"].append(completion_text) - table_data["Predicted Completion (model.generate)"].append(prediction_text) - table_data["Predicted Completion (trainer.prediction_step)"].append(pred_step_text) + table_data["Predicted Completion (model.generate)"].append( + prediction_text + ) + table_data[ + "Predicted Completion (trainer.prediction_step)" + ].append(pred_step_text) row_index += 1 if logger == "wandb": - wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] + wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] elif logger == "mlflow": - tracking_uri = AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri - mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json", tracking_uri = tracking_uri) + tracking_uri = AxolotlInputConfig( + **self.cfg.to_dict() + ).mlflow_tracking_uri + mlflow.log_table( + data=table_data, + artifact_file="PredictionsVsGroundTruth.json", + tracking_uri=tracking_uri, + ) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader)