Commit

modified to use tmp_path
snarayan21 committed Sep 27, 2023
1 parent 87a92bf commit 3687be2
Showing 1 changed file with 8 additions and 10 deletions.
tests/test_training.py (18 changes: 8 additions & 10 deletions)
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
 import os
+import pathlib
 import shutil
 import sys
 from argparse import Namespace
@@ -22,10 +23,9 @@
 from scripts.train.train import main  # noqa: E402


-def create_c4_dataset_xsmall(prefix: str) -> str:
+def create_c4_dataset_xsmall(path: pathlib.Path) -> str:
     """Creates a small mocked version of the C4 dataset."""
-    c4_dir = os.path.join(os.getcwd(), f'my-copy-c4-{prefix}')
-    shutil.rmtree(c4_dir, ignore_errors=True)
+    c4_dir = os.path.join(path, f'my-copy-c4')
     downloaded_split = 'val_xsmall'  # very fast to convert

     # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188
@@ -55,10 +55,9 @@ def create_c4_dataset_xsmall(prefix: str) -> str:
     return c4_dir


-def create_arxiv_dataset(prefix: str) -> str:
+def create_arxiv_dataset(path: pathlib.Path) -> str:
     """Creates an arxiv dataset."""
-    arxiv_dir = os.path.join(os.getcwd(), f'my-copy-arxiv-{prefix}')
-    shutil.rmtree(arxiv_dir, ignore_errors=True)
+    arxiv_dir = os.path.join(path, f'my-copy-arxiv')
     downloaded_split = 'train'

     main_json(
@@ -75,7 +74,6 @@ def create_arxiv_dataset(prefix: str) -> str:
                 'num_workers': None
             }))

-    assert os.path.exists(arxiv_dir)
     return arxiv_dir


@@ -179,16 +177,16 @@ def test_train_gauntlet(set_correct_cwd: Any):
     assert inmemorylogger.data['icl/metrics/eval_gauntlet/average'][-1][-1] == 0


-def test_train_multi_eval(set_correct_cwd: Any):
+def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path):
     """Test training run with multiple eval datasets."""
-    c4_dataset_name = create_c4_dataset_xsmall('multi-eval')
+    c4_dataset_name = create_c4_dataset_xsmall(tmp_path)
     test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu')
     # Set up multiple eval dataloaders
     first_eval_loader = test_cfg.eval_loader
     first_eval_loader.label = 'c4'
     # Create second eval dataloader using the arxiv dataset.
     second_eval_loader = copy.deepcopy(first_eval_loader)
-    arxiv_dataset_name = create_arxiv_dataset('multi-eval')
+    arxiv_dataset_name = create_arxiv_dataset(tmp_path)
     second_eval_loader.data_local = arxiv_dataset_name
     second_eval_loader.label = 'arxiv'
     test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader])
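For readers unfamiliar with the fixture, below is a minimal, self-contained sketch of the pattern this commit adopts; the helper and test names are illustrative and not part of llm-foundry. pytest's built-in tmp_path fixture gives each test function its own empty pathlib.Path directory and discards it after the test, so helpers no longer need to build directories under os.getcwd() or clean them up with shutil.rmtree.

import pathlib


def create_mock_dataset(path: pathlib.Path) -> str:
    """Builds a tiny mock dataset inside the directory pytest provides."""
    dataset_dir = path / 'my-mock-dataset'
    dataset_dir.mkdir()
    (dataset_dir / 'shard.txt').write_text('hello world')
    # No shutil.rmtree needed: pytest discards tmp_path once the test finishes.
    return str(dataset_dir)


def test_uses_tmp_path(tmp_path: pathlib.Path):
    # tmp_path is a built-in pytest fixture: a unique, empty directory per test.
    dataset = create_mock_dataset(tmp_path)
    assert pathlib.Path(dataset).is_dir()

A per-test directory also keeps concurrent or repeated test runs from colliding on a shared directory name, which appears to be what the removed prefix argument was working around.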
