diff --git a/README.md b/README.md index 1629ae251d..41521f081c 100644 --- a/README.md +++ b/README.md @@ -678,6 +678,10 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: +# If false, the datasets will not be shuffled and will keep their original order in `datasets`. +# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true. +shuffle_merged_datasets: true + # A list of one or more datasets to eval the model with. # You can use either test_datasets, or val_set_size, but not both. test_datasets: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index ef31c05c22..b1c395bcc8 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -416,6 +416,7 @@ class Config: datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore + shuffle_merged_datasets: Optional[bool] = True dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None dataset_shard_idx: Optional[int] = None diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index ad3a5cb2d8..9e0049e659 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -415,8 +415,11 @@ def for_d_in_datasets(dataset_configs): dataset = concatenate_datasets(datasets) if len(datasets) > 1: - LOG.info("shuffle merged datasets") - dataset = dataset.shuffle(seed=seed) + if cfg.shuffle_merged_datasets: + LOG.debug("shuffle merged datasets") + dataset = dataset.shuffle(seed=seed) + else: + LOG.debug("NOT shuffling merged datasets") dataset, _ = process_datasets_for_packing(cfg, dataset, None) @@ -819,7 +822,11 @@ def wrap_pretraining_dataset( else: encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) + if cfg.shuffle_merged_datasets: + dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) + else: + LOG.debug("NOT shuffling merged pretraining datasets") + dataset = dataset.map( encode, batched=True,