From f2e9c14a395100a0fa211ad243119712292b5174 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 25 Oct 2023 20:55:03 -0400
Subject: [PATCH] support for explicit test_dataset definition for evals

---
 src/axolotl/utils/config.py |  5 +++
 src/axolotl/utils/data.py   | 62 +++++++++++++++++++++++--------------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 81660ae658..063cfee406 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -369,6 +369,11 @@ def validate_config(cfg):
             "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
         )
 
+    if cfg.test_datasets and cfg.val_set_size:
+        raise ValueError(
+            "non-zero val_set_size should not be used with test_datasets configuration"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 99697de32d..563a063dbf 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -3,7 +3,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from datasets import (
@@ -57,9 +57,17 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
+            if cfg.test_datasets:
+                train_dataset, _ = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+                )
+                _, eval_dataset = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+                )
+            else:
+                train_dataset, eval_dataset = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+                )
     else:
         train_dataset = load_pretraining_dataset(
             cfg.pretraining_dataset,
@@ -87,8 +95,12 @@ def prepare_dataset(cfg, tokenizer):
 
 
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    tokenizer,
+    cfg,
+    default_dataset_prepared_path,
+    split="train",
 ) -> DatasetDict:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -96,7 +108,7 @@
             str(cfg.sequence_len)
             + "@"
             + "|".join(
-                sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg_datasets])
             )
             + "|"
             + tokenizer_name
@@ -116,7 +128,7 @@
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except  # nosec
             pass
 
@@ -147,8 +159,8 @@ def for_d_in_datasets(dataset_configs):
                     yield dataset
 
         # pylint: disable=invalid-name
-        for d in for_d_in_datasets(cfg.datasets):
-            ds: Union[Dataset, DatasetDict] = None
+        for d in for_d_in_datasets(cfg_datasets):
+            ds: Optional[Union[Dataset, DatasetDict]] = None
             ds_from_hub = False
             try:
                 load_dataset(
@@ -232,8 +244,9 @@ def for_d_in_datasets(dataset_configs):
                raise ValueError("unhandled dataset load")
            # support for using a subset of the data
            if d.shards:
-                if "train" in ds:
-                    ds = ds.shuffle(seed=seed)["train"].shard(
+                shard_split = d.split if d.split else split
+                if shard_split in ds:
+                    ds = ds.shuffle(seed=seed)[shard_split].shard(
                        num_shards=d.shards, index=0
                    )
                else:
@@ -245,17 +258,13 @@ def for_d_in_datasets(dataset_configs):
                d_type_split = d_type.split(":")
                d_base_type = d_type_split[0]
                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-            if "train" in ds:
-                ds = ds["train"]
-            elif (
-                isinstance(ds, DatasetDict)
-                and d.train_on_split
-                and d.train_on_split in ds
-            ):
-                ds = ds[d.train_on_split]
+            if d.split and d.split in ds:
+                ds = ds[d.split]
+            elif split in ds:
+                ds = ds[split]
            elif isinstance(ds, DatasetDict):
                raise ValueError(
-                    f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
+                    f"no {split} split found for dataset {d.path}, you may specify a split with `split: ...`"
                )
            if (
                "input_ids" in ds.features
@@ -375,7 +384,9 @@ def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
+    split: str = "train",
 ) -> Tuple[Dataset, Dataset]:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -395,7 +406,7 @@ def load_prepare_datasets(
             + str(max_packed_sequence_len)
             + seed
             + "|".join(
-                sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg_datasets])
             )
             + "|"
             + tokenizer_name
@@ -419,7 +430,7 @@
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except  # nosec
             pass
 
@@ -440,7 +451,7 @@
         )
     else:
         dataset = load_tokenized_prepared_datasets(
-            tokenizer, cfg, default_dataset_prepared_path
+            tokenizer, cfg, default_dataset_prepared_path, split=split
         )
 
     if cfg.seed:
@@ -494,7 +505,7 @@
             index=cfg.dataset_shard_idx,
         )
 
-    if cfg.val_set_size:
+    if split == "train" and cfg.val_set_size:
         # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
         to_hash_train = (
             dataset._fingerprint  # pylint: disable=protected-access
@@ -528,6 +539,9 @@
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    elif split == "test":
+        train_dataset = None
+        eval_dataset = dataset
     else:
         train_dataset = dataset
         eval_dataset = None
 
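
For reviewers, a minimal self-contained sketch of the control flow this patch introduces. This is not the axolotl API: validate, prepare, and fake_load below are hypothetical stand-ins for validate_config, prepare_dataset, and load_prepare_datasets, and cfg is a bare namespace rather than axolotl's config object.

    # Sketch only: hypothetical stand-ins for the patched functions.
    from types import SimpleNamespace

    def validate(cfg):
        # mirrors the new check added to validate_config()
        if cfg.test_datasets and cfg.val_set_size:
            raise ValueError(
                "non-zero val_set_size should not be used with test_datasets configuration"
            )

    def prepare(cfg, load):
        # mirrors the new branch in prepare_dataset(): when test_datasets
        # is set, make one pass per split instead of splitting train data
        if cfg.test_datasets:
            train_dataset, _ = load(cfg, split="train")
            _, eval_dataset = load(cfg, split="test")
        else:
            train_dataset, eval_dataset = load(cfg, split="train")
        return train_dataset, eval_dataset

    def fake_load(cfg, split="train"):
        # toy loader standing in for load_prepare_datasets(); picks the
        # dataset list by split, as the patched cfg_datasets line does
        datasets = cfg.test_datasets if split == "test" else cfg.datasets
        data = [f"{d}:{split}" for d in datasets]
        return (data, None) if split == "train" else (None, data)

    cfg = SimpleNamespace(datasets=["alpaca"], test_datasets=["alpaca-eval"], val_set_size=0)
    validate(cfg)
    print(prepare(cfg, fake_load))  # (['alpaca:train'], ['alpaca-eval:test'])

Note that the two load_prepare_datasets passes each hash their own dataset list (cfg_datasets), so the prepared-dataset caches for the train and test splits stay distinct.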
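Similarly hedged, the split-resolution precedence in the rewritten branch of load_tokenized_prepared_datasets: a per-dataset split takes priority, then the split passed by the caller, and a DatasetDict containing neither raises. resolve_split is a hypothetical helper with a plain dict standing in for datasets.DatasetDict.

    # resolve_split is hypothetical; it mirrors the patch's precedence order.
    def resolve_split(ds: dict, d_split=None, split="train"):
        if d_split and d_split in ds:
            return ds[d_split]  # dataset-level `split:` wins
        if split in ds:
            return ds[split]  # fall back to the caller's split
        raise ValueError(f"no {split} split found, you may specify a split with `split: ...`")

    ds = {"train": [1, 2], "validation": [3]}
    assert resolve_split(ds) == [1, 2]                     # caller default
    assert resolve_split(ds, d_split="validation") == [3]  # explicit override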