From bf0804447c5f3cbaa65ea98ab6123b549c9629ce Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 13 Sep 2023 10:36:16 -0400 Subject: [PATCH] fix wandb so mypy doesn't complain (#562) * fix wandb so mypy doesn't complain * fix wandb so mypy doesn't complain * no need for mypy override anymore --- requirements.txt | 1 + scripts/finetune.py | 2 +- src/axolotl/utils/callbacks.py | 2 +- src/axolotl/utils/{wandb.py => wandb_.py} | 0 4 files changed, 3 insertions(+), 2 deletions(-) rename src/axolotl/utils/{wandb.py => wandb_.py} (100%) diff --git a/requirements.txt b/requirements.txt index 1e95b716ec..4ef9f5fd2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,3 +30,4 @@ scipy scikit-learn==1.2.2 pynvml art +wandb diff --git a/scripts/finetune.py b/scripts/finetune.py index b998edc798..ca72c79106 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -26,7 +26,7 @@ from axolotl.utils.distributed import is_main_process from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.wandb import setup_wandb_env_vars +from axolotl.utils.wandb_ import setup_wandb_env_vars project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 99c7b147a5..819360f1d3 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -367,7 +367,7 @@ def on_evaluate( output_scores=False, ) - def logits_to_tokens(logits) -> str: + def logits_to_tokens(logits) -> torch.Tensor: probabilities = torch.softmax(logits, dim=-1) # Get the predicted token ids (the ones with the highest probability) predicted_token_ids = torch.argmax(probabilities, dim=-1) diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb_.py similarity index 100% rename from src/axolotl/utils/wandb.py rename to src/axolotl/utils/wandb_.py