Skip to content

Commit

Permalink
chore: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Apr 9, 2024
1 parent 2b87bde commit 934c4f9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
6 changes: 5 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,11 @@ def get_post_trainer_create_callbacks(self, trainer):
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_mlflow and is_mlflow_available() and self.cfg.eval_table_size > 0:
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
Expand Down
24 changes: 17 additions & 7 deletions src/axolotl/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List

import evaluate
import mlflow
Expand Down Expand Up @@ -599,7 +599,7 @@ def find_ranges(lst):
return ranges

def log_table_from_dataloader(name: str, table_dataloader):
table_data = {
table_data: Dict[str, List[Any]] = {
"id": [],
"Prompt": [],
"Correct Completion": [],
Expand Down Expand Up @@ -712,14 +712,24 @@ def log_table_from_dataloader(name: str, table_dataloader):
table_data["id"].append(row_index)
table_data["Prompt"].append(prompt_text)
table_data["Correct Completion"].append(completion_text)
table_data["Predicted Completion (model.generate)"].append(prediction_text)
table_data["Predicted Completion (trainer.prediction_step)"].append(pred_step_text)
table_data["Predicted Completion (model.generate)"].append(
prediction_text
)
table_data[
"Predicted Completion (trainer.prediction_step)"
].append(pred_step_text)
row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
elif logger == "mlflow":
tracking_uri = AxolotlInputConfig(**self.cfg.to_dict()).mlflow_tracking_uri
mlflow.log_table(data=table_data, artifact_file="PredictionsVsGroundTruth.json", tracking_uri = tracking_uri)
tracking_uri = AxolotlInputConfig(
**self.cfg.to_dict()
).mlflow_tracking_uri
mlflow.log_table(
data=table_data,
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)

if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
Expand Down

0 comments on commit 934c4f9

Please sign in to comment.