Skip to content

Commit

Permalink
parametrize
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml committed Oct 24, 2023
1 parent c3ac5f9 commit 007272c
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ def set_correct_cwd():
if os.getcwd().endswith('llm-foundry/scripts'):
os.chdir('..')


def test_train_gauntlet(set_correct_cwd: Any, tmp_path: pathlib.Path):
@pytest.mark.parametrize('averages', [
{"core_average": ["language_understanding_lite"]}, None
])
def test_train_gauntlet(averages, set_correct_cwd: Any, tmp_path: pathlib.Path):
"""Test training run with a small dataset."""
dataset_name = create_c4_dataset_xsmall(tmp_path)
test_cfg = gpt_tiny_cfg(dataset_name, 'cpu')
Expand Down Expand Up @@ -155,6 +157,9 @@ def test_train_gauntlet(set_correct_cwd: Any, tmp_path: pathlib.Path):
])
})

if averages is not None:
test_cfg.eval_gauntlet['averages'] = averages

test_cfg.icl_seq_len = 128
test_cfg.max_duration = '1ba'
test_cfg.eval_interval = '1ba'
Expand All @@ -167,14 +172,17 @@ def test_train_gauntlet(set_correct_cwd: Any, tmp_path: pathlib.Path):
inmemorylogger = trainer.logger.destinations[
0] # pyright: ignore [reportGeneralTypeIssues]
assert isinstance(inmemorylogger, InMemoryLogger)
assert 'icl/metrics/eval_gauntlet/default_average' in inmemorylogger.data.keys()
assert isinstance(inmemorylogger.data['icl/metrics/eval_gauntlet/default_average'],


category_name = 'default_average' if averages is None else 'core_average'
assert f'icl/metrics/eval_gauntlet/{category_name}' in inmemorylogger.data.keys()
assert isinstance(inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'],
list)
assert len(inmemorylogger.data['icl/metrics/eval_gauntlet/default_average'][-1]) > 0
assert len(inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'][-1]) > 0
assert isinstance(
inmemorylogger.data['icl/metrics/eval_gauntlet/default_average'][-1], tuple)
inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'][-1], tuple)

assert inmemorylogger.data['icl/metrics/eval_gauntlet/default_average'][-1][-1] == 0
assert inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'][-1][-1] == 0


def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path):
Expand Down

0 comments on commit 007272c

Please sign in to comment.