diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 6c9bc68159..e9e0f4fa69 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,4 +1,5 @@
"""Module containing data utilities"""
+
import functools
import hashlib
import logging
@@ -223,7 +224,7 @@ def for_d_in_datasets(dataset_configs):
token=use_auth_token,
)
ds_from_hub = True
- except (FileNotFoundError, ConnectionError, HFValidationError):
+ except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
@@ -290,14 +291,17 @@ def for_d_in_datasets(dataset_configs):
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
- # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
- ds = load_dataset(
- config_dataset.path,
- name=config_dataset.name,
- data_files=config_dataset.data_files,
- streaming=False,
- split=None,
- )
+ if config_dataset.data_files:
+ ds_type = get_ds_type(config_dataset)
+ ds = load_dataset(
+ ds_type,
+ name=config_dataset.name,
+ data_files=config_dataset.data_files,
+ streaming=False,
+ split=None,
+ )
+ else:
+ ds = load_from_disk(config_dataset.path)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
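For context on the data.py hunk above: when a local directory is given, the loader now branches on whether the dataset config lists `data_files`. Below is a minimal sketch of the two config shapes this implies, written as the plain dicts the new tests pass to `DictDefault`; the paths are hypothetical and only illustrate which branch each shape takes.

```python
# Sketch only: hypothetical paths, mirroring the dataset configs used in the
# tests added below.

# A directory produced by `Dataset.save_to_disk` lists no data_files, so the
# new branch loads it with `load_from_disk`.
saved_to_disk_config = {
    "path": "/data/my_saved_dataset",  # hypothetical save_to_disk output dir
    "type": "alpaca",
}

# A directory of raw shards must list ds_type and data_files so the loader
# knows to call `load_dataset(ds_type, data_files=...)` instead.
raw_shard_config = {
    "path": "/data/my_parquet_dir",  # hypothetical directory of parquet shards
    "ds_type": "parquet",
    "data_files": ["/data/my_parquet_dir/shard1.parquet"],
    "type": "alpaca",
}
```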
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 0000000000..8b7b3dae6a
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,272 @@
+"""
+Test dataset loading under various conditions.
+"""
+
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+
+from datasets import Dataset
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer
+
+from axolotl.utils.data import load_tokenized_prepared_datasets
+from axolotl.utils.dict import DictDefault
+
+
+class TestDatasetPreparation(unittest.TestCase):
+    """Test dataset preparation under various configurations."""
+
+ def setUp(self) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+ self.tokenizer.add_special_tokens(
+ {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+ }
+ )
+ # Alpaca dataset.
+ self.dataset = Dataset.from_list(
+ [
+ {
+ "instruction": "Evaluate this sentence for spelling and grammar mistakes",
+ "input": "He finnished his meal and left the resturant",
+ "output": "He finished his meal and left the restaurant.",
+ }
+ ]
+ )
+
+ def test_load_hub(self):
+ """Core use case. Verify that processing data from the hub works"""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ prepared_path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 2000
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+ def test_load_local_hub(self):
+ """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
+ tmp_ds_path.mkdir(parents=True, exist_ok=True)
+ snapshot_download(
+ repo_id="mhenrichsen/alpaca_2k_test",
+ repo_type="dataset",
+ local_dir=tmp_ds_path,
+ )
+
+ prepared_path = Path(tmp_dir) / "prepared"
+            # Right now, a local copy that doesn't fully conform to a dataset
+            # layout must list data_files and ds_type; otherwise the loader
+            # won't know how to load it.
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "ds_type": "parquet",
+ "type": "alpaca",
+ "data_files": [
+ "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
+ ],
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 2000
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+ shutil.rmtree(tmp_ds_path)
+
+ def test_load_from_save_to_disk(self):
+ """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
+ self.dataset.save_to_disk(tmp_ds_name)
+
+ prepared_path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 256,
+ "datasets": [
+ {
+ "path": str(tmp_ds_name),
+ "type": "alpaca",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 1
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+ def test_load_from_dir_of_parquet(self):
+ """Usual use case. Verify a directory of parquet files can be loaded."""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
+ tmp_ds_dir.mkdir()
+ tmp_ds_path = tmp_ds_dir / "shard1.parquet"
+ self.dataset.to_parquet(tmp_ds_path)
+
+ prepared_path: Path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 256,
+ "datasets": [
+ {
+ "path": str(tmp_ds_dir),
+ "ds_type": "parquet",
+ "name": "test_data",
+ "data_files": [
+ str(tmp_ds_path),
+ ],
+ "type": "alpaca",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 1
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+ def test_load_from_dir_of_json(self):
+ """Standard use case. Verify a directory of json files can be loaded."""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
+ tmp_ds_dir.mkdir()
+ tmp_ds_path = tmp_ds_dir / "shard1.json"
+ self.dataset.to_json(tmp_ds_path)
+
+ prepared_path: Path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 256,
+ "datasets": [
+ {
+ "path": str(tmp_ds_dir),
+ "ds_type": "json",
+ "name": "test_data",
+ "data_files": [
+ str(tmp_ds_path),
+ ],
+ "type": "alpaca",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 1
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+ def test_load_from_single_parquet(self):
+ """Standard use case. Verify a single parquet file can be loaded."""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
+ self.dataset.to_parquet(tmp_ds_path)
+
+ prepared_path: Path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 256,
+ "datasets": [
+ {
+ "path": str(tmp_ds_path),
+ "name": "test_data",
+ "type": "alpaca",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 1
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+ def test_load_from_single_json(self):
+ """Standard use case. Verify a single json file can be loaded."""
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
+ self.dataset.to_json(tmp_ds_path)
+
+ prepared_path: Path = Path(tmp_dir) / "prepared"
+ cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 256,
+ "datasets": [
+ {
+ "path": str(tmp_ds_path),
+ "name": "test_data",
+ "type": "alpaca",
+ },
+ ],
+ }
+ )
+
+ dataset, _ = load_tokenized_prepared_datasets(
+ self.tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 1
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+
+
+if __name__ == "__main__":
+ unittest.main()
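
As a quick sanity check of the round trip the new `load_from_disk` branch relies on, the `datasets` API can be exercised directly outside the test harness; this is a standalone sketch, not part of the suite above.

```python
import tempfile
from pathlib import Path

from datasets import Dataset, load_from_disk

with tempfile.TemporaryDirectory() as tmp_dir:
    ds_path = Path(tmp_dir) / "tmp_dataset"
    # Save a one-row dataset the same way test_load_from_save_to_disk does.
    Dataset.from_list(
        [{"instruction": "hi", "input": "", "output": "hello"}]
    ).save_to_disk(str(ds_path))

    # Reload it; this is the call the new directory branch falls back to
    # when no data_files are configured.
    reloaded = load_from_disk(str(ds_path))
    assert len(reloaded) == 1
    assert reloaded[0]["output"] == "hello"
```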