diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6c9bc68159..e9e0f4fa69 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,4 +1,5 @@ """Module containing data utilities""" + import functools import hashlib import logging @@ -223,7 +224,7 @@ def for_d_in_datasets(dataset_configs): token=use_auth_token, ) ds_from_hub = True - except (FileNotFoundError, ConnectionError, HFValidationError): + except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): pass ds_from_cloud = False @@ -290,14 +291,17 @@ def for_d_in_datasets(dataset_configs): local_path = Path(config_dataset.path) if local_path.exists(): if local_path.is_dir(): - # TODO dirs with arrow or parquet files could be loaded with `load_from_disk` - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - data_files=config_dataset.data_files, - streaming=False, - split=None, - ) + if config_dataset.data_files: + ds_type = get_ds_type(config_dataset) + ds = load_dataset( + ds_type, + name=config_dataset.name, + data_files=config_dataset.data_files, + streaming=False, + split=None, + ) + else: + ds = load_from_disk(config_dataset.path) elif local_path.is_file(): ds_type = get_ds_type(config_dataset) diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 0000000000..8b7b3dae6a --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,272 @@ +""" +Test dataset loading under various conditions. +""" + +import shutil +import tempfile +import unittest +from pathlib import Path + +from datasets import Dataset +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer + +from axolotl.utils.data import load_tokenized_prepared_datasets +from axolotl.utils.dict import DictDefault + + +class TestDatasetPreparation(unittest.TestCase): + """Test a configured dataloader.""" + + def setUp(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + } + ) + # Alpaca dataset. + self.dataset = Dataset.from_list( + [ + { + "instruction": "Evaluate this sentence for spelling and grammar mistakes", + "input": "He finnished his meal and left the resturant", + "output": "He finished his meal and left the restaurant.", + } + ] + ) + + def test_load_hub(self): + """Core use case. Verify that processing data from the hub works""" + with tempfile.TemporaryDirectory() as tmp_dir: + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_local_hub(self): + """Niche use case. Verify that a local copy of a hub dataset can be loaded""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") + tmp_ds_path.mkdir(parents=True, exist_ok=True) + snapshot_download( + repo_id="mhenrichsen/alpaca_2k_test", + repo_type="dataset", + local_dir=tmp_ds_path, + ) + + prepared_path = Path(tmp_dir) / "prepared" + # Right now a local copy that doesn't fully conform to a dataset + # must list data_files and ds_type otherwise the loader won't know + # how to load it. + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "ds_type": "parquet", + "type": "alpaca", + "data_files": [ + "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", + ], + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + shutil.rmtree(tmp_ds_path) + + def test_load_from_save_to_disk(self): + """Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_name = Path(tmp_dir) / "tmp_dataset" + self.dataset.save_to_disk(tmp_ds_name) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 256, + "datasets": [ + { + "path": str(tmp_ds_name), + "type": "alpaca", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 1 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_from_dir_of_parquet(self): + """Usual use case. Verify a directory of parquet files can be loaded.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" + tmp_ds_dir.mkdir() + tmp_ds_path = tmp_ds_dir / "shard1.parquet" + self.dataset.to_parquet(tmp_ds_path) + + prepared_path: Path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 256, + "datasets": [ + { + "path": str(tmp_ds_dir), + "ds_type": "parquet", + "name": "test_data", + "data_files": [ + str(tmp_ds_path), + ], + "type": "alpaca", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 1 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_from_dir_of_json(self): + """Standard use case. Verify a directory of json files can be loaded.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" + tmp_ds_dir.mkdir() + tmp_ds_path = tmp_ds_dir / "shard1.json" + self.dataset.to_json(tmp_ds_path) + + prepared_path: Path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 256, + "datasets": [ + { + "path": str(tmp_ds_dir), + "ds_type": "json", + "name": "test_data", + "data_files": [ + str(tmp_ds_path), + ], + "type": "alpaca", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 1 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_from_single_parquet(self): + """Standard use case. Verify a single parquet file can be loaded.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet" + self.dataset.to_parquet(tmp_ds_path) + + prepared_path: Path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 256, + "datasets": [ + { + "path": str(tmp_ds_path), + "name": "test_data", + "type": "alpaca", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 1 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_from_single_json(self): + """Standard use case. Verify a single json file can be loaded.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json" + self.dataset.to_json(tmp_ds_path) + + prepared_path: Path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 256, + "datasets": [ + { + "path": str(tmp_ds_path), + "name": "test_data", + "type": "alpaca", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 1 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + +if __name__ == "__main__": + unittest.main()