Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanHB committed Dec 3, 2024
1 parent 3481562 commit b0ca7f1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/lighteval/main_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR: str = os.getenv("HF_HOME", "/scratch")
print(f"CACHE_DIR: {CACHE_DIR}")
print(f"ENV: {os.environ}")

HELP_PANNEL_NAME_1 = "Common Paramaters"
HELP_PANNEL_NAME_2 = "Logging Parameters"
Expand Down Expand Up @@ -63,7 +65,7 @@ def accelerate( # noqa C901
] = None,
cache_dir: Annotated[
str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANNEL_NAME_1)
] = CACHE_DIR,
] = "cache/model",
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANNEL_NAME_1)
] = 1,
Expand Down Expand Up @@ -113,6 +115,7 @@ def accelerate( # noqa C901
accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
print(f"ENV CONFIG: {env_config}")
evaluation_tracker = EvaluationTracker(
output_dir=output_dir,
save_details=save_details,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def run_model_predictions_full(model: str, tasks: tuple):
output_dir="",
dataset_loading_processes=1,
save_details=True,
cache_dir="cache/models",
# cache_dir="cache/models",
)
return results

Expand All @@ -75,7 +75,7 @@ def run_model_predictions_lite(model: str, tasks: tuple):
dataset_loading_processes=1,
save_details=True,
max_samples=10,
cache_dir="cache/models",
# cache_dir="cache/models",
)
return results

Expand Down

0 comments on commit b0ca7f1

Please sign in to comment.