From 8670bedeb720502e2efc5ba6d92f89bc93632902 Mon Sep 17 00:00:00 2001 From: Scott Fleming Date: Tue, 9 Apr 2024 08:27:57 -0700 Subject: [PATCH 1/2] Correctly handle splits for datasets.arrow_dataset.Dataset objects The `load_tokenized_prepared_datasets` function currently has logic for loading a dataset from local path that always checks if a split is in the dataset. The problem is, if the dataset is loaded using `load_from_disk` and it is an Arrow-based dataset, *there is no* split information. Instead what happens is, by calling `split in ds`, it presumably searches through all the rows and columns of the arrow dataset object to find e.g., 'train' assuming `split == 'train'`. This causes the program to hang. See https://chat.openai.com/share/0d567dbd-d60b-4079-9040-e1de58a4dff3 for context. --- src/axolotl/utils/data/sft.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 9423b9a6f3..e51ab106d7 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -378,15 +378,16 @@ def for_d_in_datasets(dataset_configs): d_type_split = d_type.split(":") d_base_type = d_type_split[0] d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - - if config_dataset.split and config_dataset.split in ds: - ds = ds[config_dataset.split] - elif split in ds: - ds = ds[split] - elif isinstance(ds, DatasetDict): - raise ValueError( - f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" - ) + + if isinstance(ds, DatasetDict): + if config_dataset.split and config_dataset.split in ds: + ds = ds[config_dataset.split] + elif split in ds: + ds = ds[split] + else: + raise ValueError( + f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" + ) # support for using a subset of the data if config_dataset.shards: From c0b99b7be1ae6ee3281dc16a077c32aadae8664f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Apr 2024 11:56:02 -0400 Subject: [PATCH 2/2] chore: lint --- src/axolotl/utils/data/sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index e51ab106d7..39c50b1a07 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -378,7 +378,7 @@ def for_d_in_datasets(dataset_configs): d_type_split = d_type.split(":") d_base_type = d_type_split[0] d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - + if isinstance(ds, DatasetDict): if config_dataset.split and config_dataset.split in ds: ds = ds[config_dataset.split]