From 50d56106623ee998ed8153fa3dd742fb6ee345fe Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 26 Sep 2024 20:05:25 +0000 Subject: [PATCH] fix tests --- .../data_prep/convert_dataset_json.py | 2 +- tests/a_scripts/eval/test_eval.py | 9 +++++---- tests/a_scripts/train/test_train.py | 18 ++++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py index 35d7e637e6..c6f7d51c02 100644 --- a/llmfoundry/command_utils/data_prep/convert_dataset_json.py +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -43,7 +43,7 @@ def build_hf_dataset( no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. - Typically "all" (The Pile) or "en" (c4). + Typically "all" (The Pile) or "en" (allenai/c4). Returns: An IterableDataset. diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index 3fc7141b9a..f1b76913d1 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -157,16 +157,17 @@ def test_loader_eval( print(inmemorylogger.data.keys()) # Checks for first eval dataloader - assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert 'metrics/eval/allenai/c4/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'], list, ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], tuple, ) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 4f6a2e2ed9..b1bca9ebd0 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -154,16 +154,17 @@ def test_train_multi_eval(tmp_path: pathlib.Path): assert isinstance(inmemorylogger, InMemoryLogger) # Checks for first eval dataloader - assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert 'metrics/eval/allenai/c4/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'], list, ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], tuple, ) @@ -226,15 +227,16 @@ def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): 0] # pyright: ignore [reportGeneralTypeIssues] assert isinstance(inmemorylogger, InMemoryLogger) - assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert 'metrics/eval/allenai/c4/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'], list, ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], tuple, )