diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 9423b9a6f3..39c50b1a07 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -379,14 +379,15 @@ def for_d_in_datasets(dataset_configs): d_base_type = d_type_split[0] d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - if config_dataset.split and config_dataset.split in ds: - ds = ds[config_dataset.split] - elif split in ds: - ds = ds[split] - elif isinstance(ds, DatasetDict): - raise ValueError( - f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" - ) + if isinstance(ds, DatasetDict): + if config_dataset.split and config_dataset.split in ds: + ds = ds[config_dataset.split] + elif split in ds: + ds = ds[split] + else: + raise ValueError( + f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" + ) # support for using a subset of the data if config_dataset.shards: