diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index e756969187..78fbe52d29 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -519,6 +519,11 @@ def validate_config(cfg):
                 "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
             )

+    if cfg.test_datasets and cfg.val_set_size:
+        raise ValueError(
+            "non-zero val_set_size should not be used with test_datasets configuration"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 6726f2ad14..3691a6e145 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -4,7 +4,7 @@
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from datasets import (
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
+            if cfg.test_datasets:
+                train_dataset, _, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+                )
+                _, eval_dataset, _ = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+                )
+            else:
+                train_dataset, eval_dataset, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+                )
     else:
         path = cfg.pretraining_dataset
         name = None
@@ -108,8 +116,12 @@ def prepare_dataset(cfg, tokenizer):


 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    tokenizer,
+    cfg,
+    default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -126,7 +138,7 @@ def load_tokenized_prepared_datasets(
                 sorted(
                     [
                         f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
-                        for d in cfg.datasets
+                        for d in cfg_datasets
                     ]
                 )
             )
@@ -149,7 +161,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except  # nosec
             pass

@@ -188,8 +200,8 @@ def for_d_in_datasets(dataset_configs):
                     yield dataset

         # pylint: disable=invalid-name
-        for config_dataset in for_d_in_datasets(cfg.datasets):
-            ds: Union[Dataset, DatasetDict] = None
+        for config_dataset in for_d_in_datasets(cfg_datasets):
+            ds: Optional[Union[Dataset, DatasetDict]] = None
             ds_from_hub = False
             try:
                 load_dataset(
@@ -342,16 +354,6 @@ def for_d_in_datasets(dataset_configs):
                 )
             if not ds:
                 raise ValueError("unhandled dataset load")
-            # support for using a subset of the data
-            if config_dataset.shards:
-                if "train" in ds:
-                    ds = ds.shuffle(seed=seed)["train"].shard(
-                        num_shards=config_dataset.shards, index=0
-                    )
-                else:
-                    ds = ds.shuffle(seed=seed).shard(
-                        num_shards=config_dataset.shards, index=0
-                    )

             d_base_type = d_prompt_style = None
             d_type = config_dataset.type
@@ -359,17 +361,21 @@ def for_d_in_datasets(dataset_configs):
                 d_type_split = d_type.split(":")
                 d_base_type = d_type_split[0]
                 d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-            if "train" in ds:
-                ds = ds["train"]
-            elif (
-                isinstance(ds, DatasetDict)
-                and config_dataset.train_on_split
-                and config_dataset.train_on_split in ds
-            ):
-                ds = ds[config_dataset.train_on_split]
+
+            if config_dataset.split and config_dataset.split in ds:
+                ds = ds[config_dataset.split]
+            elif split in ds:
+                ds = ds[split]
             elif isinstance(ds, DatasetDict):
                 raise ValueError(
-                    f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
+                    f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+                )
+
+            # support for using a subset of the data
+            if config_dataset.shards:
+                shards_idx = config_dataset.get("shards_idx", 0)
+                ds = ds.shuffle(seed=seed).shard(
+                    num_shards=config_dataset.shards, index=shards_idx
                 )

             dataset_wrapper, dataset_prompter = get_dataset_wrapper(
@@ -428,6 +434,7 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
     dataset, prompters = load_tokenized_prepared_datasets(
         tokenizer, cfg, default_dataset_prepared_path
@@ -442,7 +449,7 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )

-    if cfg.val_set_size:
+    if split == "train" and cfg.val_set_size:
         # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
         to_hash_train = (
             dataset._fingerprint  # pylint: disable=protected-access
@@ -475,6 +482,9 @@ def load_prepare_datasets(

         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    elif split == "test":
+        train_dataset = None
+        eval_dataset = dataset
     else:
         train_dataset = dataset
         eval_dataset = None
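
As a quick illustration of the reworked shard handling above (this sketch is not part of the diff, and the values are made up), the dataset config's `shards` value now shards whichever split was selected, and an optional `shards_idx` key, defaulting to 0, picks which shard to keep. Using only documented `datasets` APIs:

    from datasets import Dataset

    seed = 42
    ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

    shards = 4       # `shards` from the dataset config entry
    shards_idx = 1   # the new `shards_idx` key; 0 when not set

    # shuffle deterministically, then keep a single one of the `shards` slices
    subset = ds.shuffle(seed=seed).shard(num_shards=shards, index=shards_idx)
    print(len(subset))  # 25, i.e. one quarter of the shuffled rows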