diff --git a/tests/test_training.py b/tests/test_training.py index f83c4ebd3c..cd3ebced14 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import copy import os +import pathlib import shutil import sys from argparse import Namespace @@ -22,10 +23,9 @@ from scripts.train.train import main # noqa: E402 -def create_c4_dataset_xsmall(prefix: str) -> str: +def create_c4_dataset_xsmall(path: pathlib.Path) -> str: """Creates a small mocked version of the C4 dataset.""" - c4_dir = os.path.join(os.getcwd(), f'my-copy-c4-{prefix}') - shutil.rmtree(c4_dir, ignore_errors=True) + c4_dir = os.path.join(path, f'my-copy-c4') downloaded_split = 'val_xsmall' # very fast to convert # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 @@ -55,10 +55,9 @@ def create_c4_dataset_xsmall(prefix: str) -> str: return c4_dir -def create_arxiv_dataset(prefix: str) -> str: +def create_arxiv_dataset(path: pathlib.Path) -> str: """Creates an arxiv dataset.""" - arxiv_dir = os.path.join(os.getcwd(), f'my-copy-arxiv-{prefix}') - shutil.rmtree(arxiv_dir, ignore_errors=True) + arxiv_dir = os.path.join(path, f'my-copy-arxiv') downloaded_split = 'train' main_json( @@ -75,7 +74,6 @@ def create_arxiv_dataset(prefix: str) -> str: 'num_workers': None })) - assert os.path.exists(arxiv_dir) return arxiv_dir @@ -179,16 +177,16 @@ def test_train_gauntlet(set_correct_cwd: Any): assert inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1][-1] == 0 -def test_train_multi_eval(set_correct_cwd: Any): +def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path): """Test training run with multiple eval datasets.""" - c4_dataset_name = create_c4_dataset_xsmall('multi-eval') + c4_dataset_name = create_c4_dataset_xsmall(tmp_path) test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') # Set up multiple eval dataloaders first_eval_loader = test_cfg.eval_loader first_eval_loader.label = 'c4' # Create second eval dataloader using the arxiv dataset. second_eval_loader = copy.deepcopy(first_eval_loader) - arxiv_dataset_name = create_arxiv_dataset('multi-eval') + arxiv_dataset_name = create_arxiv_dataset(tmp_path) second_eval_loader.data_local = arxiv_dataset_name second_eval_loader.label = 'arxiv' test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])