From 3bcdab41297843fc23b167bb025327dc3e875982 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 16 Nov 2023 13:21:52 -0500
Subject: [PATCH] support for explicit test_dataset definition for evals

---
 src/axolotl/utils/config.py |  5 +++
 src/axolotl/utils/data.py   | 75 +++++++++++++++++++++----------------
 2 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index d2db92a633..c68b54c2e6 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -372,6 +372,11 @@ def validate_config(cfg):
     if cfg.rope_scaling:
         LOG.warning("`rope_scaling` should now be be a key under `model_config`")
 
+    if cfg.test_datasets and cfg.val_set_size:
+        raise ValueError(
+            "non-zero val_set_size should not be used with test_datasets configuration"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 49b36202c0..ea97d00cd7 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -3,7 +3,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from datasets import (
@@ -60,9 +60,17 @@ def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
+            if cfg.test_datasets:
+                train_dataset, _, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+                )
+                _, eval_dataset, _ = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+                )
+            else:
+                train_dataset, eval_dataset, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+                )
     else:
         train_dataset = load_pretraining_dataset(
             cfg.pretraining_dataset,
@@ -98,8 +106,12 @@ def prepare_dataset(cfg, tokenizer):
 
 
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    tokenizer,
+    cfg,
+    default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -110,7 +122,7 @@ def load_tokenized_prepared_datasets(
                     sorted(
                         [
                             f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
-                            for d in cfg.datasets
+                            for d in cfg_datasets
                         ]
                     )
                 )
@@ -133,7 +145,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
     except Exception:  # pylint: disable=broad-except # nosec
         pass
 
@@ -164,8 +176,8 @@ def for_d_in_datasets(dataset_configs):
                 yield dataset
 
     # pylint: disable=invalid-name
-    for config_dataset in for_d_in_datasets(cfg.datasets):
-        ds: Union[Dataset, DatasetDict] = None
+    for config_dataset in for_d_in_datasets(cfg_datasets):
+        ds: Optional[Union[Dataset, DatasetDict]] = None
         ds_from_hub = False
         try:
             load_dataset(
@@ -311,16 +323,6 @@ def for_d_in_datasets(dataset_configs):
             )
         if not ds:
             raise ValueError("unhandled dataset load")
-        # support for using a subset of the data
-        if config_dataset.shards:
-            if "train" in ds:
-                ds = ds.shuffle(seed=seed)["train"].shard(
-                    num_shards=config_dataset.shards, index=0
-                )
-            else:
-                ds = ds.shuffle(seed=seed).shard(
-                    num_shards=config_dataset.shards, index=0
-                )
         d_base_type = d_prompt_style = None
 
         d_type = config_dataset.type
@@ -328,17 +330,21 @@ def for_d_in_datasets(dataset_configs):
            d_type_split = d_type.split(":")
            d_base_type = d_type_split[0]
            d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-        if "train" in ds:
-            ds = ds["train"]
-        elif (
-            isinstance(ds, DatasetDict)
-            and config_dataset.train_on_split
-            and config_dataset.train_on_split in ds
-        ):
-            ds = ds[config_dataset.train_on_split]
+
+        if config_dataset.split and config_dataset.split in ds:
+            ds = ds[config_dataset.split]
+        elif split in ds:
+            ds = ds[split]
         elif isinstance(ds, DatasetDict):
             raise ValueError(
-                f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
+                f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+            )
+
+        # support for using a subset of the data
+        if config_dataset.shards:
+            shards_idx = config_dataset.get("shards_idx", 0)
+            ds = ds.shuffle(seed=seed).shard(
+                num_shards=config_dataset.shards, index=shards_idx
             )
 
         dataset_wrapper, dataset_prompter = get_dataset_wrapper(
@@ -394,7 +400,9 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
+    cfg_datasets = cfg.eval_datasets if split != "train" else cfg.datasets
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -415,7 +423,7 @@ def load_prepare_datasets(
                 + str(max_packed_sequence_len)
                 + seed
                 + "|".join(
-                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg_datasets])
                 )
                 + "|"
                 + tokenizer_name
@@ -439,7 +447,7 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
     except Exception:  # pylint: disable=broad-except # nosec
         pass
 
@@ -460,7 +468,7 @@ def load_prepare_datasets(
         )
     else:
         dataset, prompters = load_tokenized_prepared_datasets(
-            tokenizer, cfg, default_dataset_prepared_path
+            tokenizer, cfg, default_dataset_prepared_path, split=split
         )
 
     if cfg.seed:
@@ -514,7 +522,7 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )
 
-    if cfg.val_set_size:
+    if split == "train" and cfg.val_set_size:
         # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
         to_hash_train = (
             dataset._fingerprint  # pylint: disable=protected-access
@@ -547,6 +555,9 @@ def load_prepare_datasets(
 
         train_dataset = dataset["train"]
        eval_dataset = dataset["test"]
+    elif split == "test":
+        train_dataset = None
+        eval_dataset = dataset
     else:
         train_dataset = dataset
         eval_dataset = None
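
The patch itself ships no example configuration, so below is a minimal sketch of the config shape the new code path expects. It is illustrative only: the dataset paths and the "alpaca" type are hypothetical placeholders, and it is written as a plain Python dict (axolotl configs live in YAML on disk) to stay in the same language as the diff. The keys mirror what the diff reads from each dataset entry: `path`, `type`, the per-entry `split`, optional `shards`/`shards_idx` for subsetting, plus a zero `val_set_size`, which the new `validate_config` check requires whenever `test_datasets` is set.

# Hypothetical config sketch (not part of the patch). Field names follow what
# the diff reads from each dataset entry; paths and the "alpaca" type are
# placeholders.
cfg_sketch = {
    "datasets": [
        # consumed by load_prepare_datasets(..., split="train")
        {"path": "org/my-train-data", "type": "alpaca"},
    ],
    "test_datasets": [
        # consumed by the extra load_prepare_datasets(..., split="test") call
        # added in prepare_dataset; "split" names the HF split to load, while
        # "shards"/"shards_idx" optionally select a subset of it
        {
            "path": "org/my-eval-data",
            "type": "alpaca",
            "split": "test",
            "shards": 4,
            "shards_idx": 0,
        },
    ],
    # validate_config now raises if test_datasets is combined with a non-zero
    # val_set_size, so this must stay 0 (or be omitted)
    "val_set_size": 0,
}

With a config shaped like this, prepare_dataset makes two passes through load_prepare_datasets: one with split="train" over `datasets` for the training set and one with split="test" over `test_datasets` for the eval set, instead of carving an eval split out of the training data via val_set_size.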