fix: loading locally downloaded dataset (#2056) [skip ci]
NanoCode012 authored Nov 16, 2024
1 parent d42f202 commit fd70eec
Showing 2 changed files with 80 additions and 37 deletions.
src/axolotl/utils/data/sft.py (9 additions, 1 deletion)
@@ -350,7 +350,15 @@ def for_d_in_datasets(dataset_configs):
                         split=None,
                     )
                 else:
-                    ds = load_from_disk(config_dataset.path)
+                    try:
+                        ds = load_from_disk(config_dataset.path)
+                    except FileNotFoundError:
+                        ds = load_dataset(
+                            config_dataset.path,
+                            name=config_dataset.name,
+                            streaming=False,
+                            split=None,
+                        )
             elif local_path.is_file():
                 ds_type = get_ds_type(config_dataset)

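For context, `load_from_disk` only reads directories produced by `Dataset.save_to_disk`; a folder populated by `snapshot_download` carries no such metadata, so it raises the FileNotFoundError that the hunk above catches. A minimal standalone sketch of the fallback (the helper name is illustrative, not part of the codebase):

# Sketch of the fallback behavior added above, assuming only the
# `datasets` package is installed.
from datasets import load_dataset, load_from_disk


def load_local_dataset(path, name=None):
    try:
        # Succeeds when `path` was written by Dataset.save_to_disk().
        return load_from_disk(path)
    except FileNotFoundError:
        # A raw snapshot of a hub dataset repo has no save_to_disk
        # metadata, so read its data files with the generic loader.
        return load_dataset(path, name=name, streaming=False, split=None)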
tests/test_datasets.py (71 additions, 36 deletions)
@@ -371,44 +371,79 @@ def test_load_hub_with_revision_with_dpo(self):
     def test_load_local_hub_with_revision(self):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
-            with tempfile.TemporaryDirectory() as tmp_dir2:
-                tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
-                tmp_ds_path.mkdir(parents=True, exist_ok=True)
-                snapshot_download(
-                    repo_id="mhenrichsen/alpaca_2k_test",
-                    repo_type="dataset",
-                    local_dir=tmp_ds_path,
-                    revision="d05c1cb",
-                )
-
-                prepared_path = Path(tmp_dir) / "prepared"
-                cfg = DictDefault(
-                    {
-                        "tokenizer_config": "huggyllama/llama-7b",
-                        "sequence_len": 1024,
-                        "datasets": [
-                            {
-                                "path": "mhenrichsen/alpaca_2k_test",
-                                "ds_type": "parquet",
-                                "type": "alpaca",
-                                "data_files": [
-                                    f"{tmp_ds_path}/alpaca_2000.parquet",
-                                ],
-                                "revision": "d05c1cb",
-                            },
-                        ],
-                    }
-                )
+            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+                revision="d05c1cb",
+            )
+
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 1024,
+                    "datasets": [
+                        {
+                            "path": "mhenrichsen/alpaca_2k_test",
+                            "ds_type": "parquet",
+                            "type": "alpaca",
+                            "data_files": [
+                                f"{tmp_ds_path}/alpaca_2000.parquet",
+                            ],
+                            "revision": "d05c1cb",
+                        },
+                    ],
+                }
+            )

-                dataset, _ = load_tokenized_prepared_datasets(
-                    self.tokenizer, cfg, prepared_path
-                )
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )

-                assert len(dataset) == 2000
-                assert "input_ids" in dataset.features
-                assert "attention_mask" in dataset.features
-                assert "labels" in dataset.features
-                shutil.rmtree(tmp_ds_path)
+            assert len(dataset) == 2000
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+            shutil.rmtree(tmp_ds_path)
+
+    def test_loading_local_dataset_folder(self):
+        """Verify that a dataset downloaded to a local folder can be loaded"""
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+            )
+
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 1024,
+                    "datasets": [
+                        {
+                            "path": str(tmp_ds_path),
+                            "type": "alpaca",
+                        },
+                    ],
+                }
+            )
+
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
+
+            assert len(dataset) == 2000
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+            shutil.rmtree(tmp_ds_path)


 if __name__ == "__main__":
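As a usage note, the scenario the new test locks in can be reproduced outside the test suite roughly as follows (the local_dir value is illustrative):

# Rough reproduction of the fixed scenario; requires `huggingface_hub`.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="mhenrichsen/alpaca_2k_test",
    repo_type="dataset",
    local_dir="./alpaca_2k_test",
)
# Before this commit, a dataset config whose `path` pointed at local_dir
# failed inside load_from_disk(); it now falls back to load_dataset().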
