diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index f884903837..4415b3a639 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -467,6 +467,11 @@ def validate_config(cfg): "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." ) + if cfg.test_datasets and cfg.val_set_size: + raise ValueError( + "non-zero val_set_size should not be used with test_datasets configuration" + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 8ef3a7f78a..4fbf0b602e 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from datasets import ( @@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer): prompters = [] if not cfg.pretraining_dataset: with zero_first(is_main_process()): - train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH - ) + if cfg.test_datasets: + train_dataset, _, prompters = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" + ) + _, eval_dataset, _ = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test" + ) + else: + train_dataset, eval_dataset, prompters = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + ) else: path = cfg.pretraining_dataset name = None @@ -111,8 +119,12 @@ def prepare_dataset(cfg, tokenizer): def load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path + tokenizer, + cfg, + default_dataset_prepared_path, + split="train", ) -> Tuple[DatasetDict, List[Prompter]]: + cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets tokenizer_name = tokenizer.__class__.__name__ ds_hash = str( md5( @@ -123,7 +135,7 @@ def load_tokenized_prepared_datasets( sorted( [ f"{d.path}:{d.type}:{d.shards}:{d.conversation}" - for d in cfg.datasets + for d in cfg_datasets ] ) ) @@ -146,7 +158,7 @@ def load_tokenized_prepared_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", token=use_auth_token, ) - dataset = dataset["train"] + dataset = dataset[split] except Exception: # pylint: disable=broad-except # nosec pass @@ -177,8 +189,8 @@ def for_d_in_datasets(dataset_configs): yield dataset # pylint: disable=invalid-name - for config_dataset in for_d_in_datasets(cfg.datasets): - ds: Union[Dataset, DatasetDict] = None + for config_dataset in for_d_in_datasets(cfg_datasets): + ds: Optional[Union[Dataset, DatasetDict]] = None ds_from_hub = False try: load_dataset( @@ -331,16 +343,6 @@ def for_d_in_datasets(dataset_configs): ) if not ds: raise ValueError("unhandled dataset load") - # support for using a subset of the data - if config_dataset.shards: - if "train" in ds: - ds = ds.shuffle(seed=seed)["train"].shard( - num_shards=config_dataset.shards, index=0 - ) - else: - ds = ds.shuffle(seed=seed).shard( - num_shards=config_dataset.shards, index=0 - ) d_base_type = d_prompt_style = None d_type = config_dataset.type @@ -348,17 +350,21 @@ def for_d_in_datasets(dataset_configs): d_type_split = d_type.split(":") d_base_type = d_type_split[0] d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - if "train" in ds: - ds = ds["train"] - elif ( - isinstance(ds, DatasetDict) - and config_dataset.train_on_split - and config_dataset.train_on_split in ds - ): - ds = ds[config_dataset.train_on_split] + + if config_dataset.split and config_dataset.split in ds: + ds = ds[config_dataset.split] + elif split in ds: + ds = ds[split] elif isinstance(ds, DatasetDict): raise ValueError( - f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `" + f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" + ) + + # support for using a subset of the data + if config_dataset.shards: + shards_idx = config_dataset.get("shards_idx", 0) + ds = ds.shuffle(seed=seed).shard( + num_shards=config_dataset.shards, index=shards_idx ) dataset_wrapper, dataset_prompter = get_dataset_wrapper( @@ -414,7 +420,9 @@ def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path, + split="train", ) -> Tuple[Dataset, Dataset, List[Prompter]]: + cfg_datasets = cfg.eval_datasets if split != "train" else cfg.datasets max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -435,7 +443,7 @@ def load_prepare_datasets( + str(max_packed_sequence_len) + seed + "|".join( - sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) + sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg_datasets]) ) + "|" + tokenizer_name @@ -459,7 +467,7 @@ def load_prepare_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", token=use_auth_token, ) - dataset = dataset["train"] + dataset = dataset[split] except Exception: # pylint: disable=broad-except # nosec pass @@ -480,7 +488,7 @@ def load_prepare_datasets( ) else: dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path + tokenizer, cfg, default_dataset_prepared_path, split=split ) if cfg.seed: @@ -534,7 +542,7 @@ def load_prepare_datasets( index=cfg.dataset_shard_idx, ) - if cfg.val_set_size: + if split == "train" and cfg.val_set_size: # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( dataset._fingerprint # pylint: disable=protected-access @@ -567,6 +575,9 @@ def load_prepare_datasets( train_dataset = dataset["train"] eval_dataset = dataset["test"] + elif split == "test": + train_dataset = None + eval_dataset = dataset else: train_dataset = dataset eval_dataset = None