support for explicit test_dataset definition for evals
winglian committed Jan 2, 2024
1 parent 4d2e842 commit fa4f187
Showing 2 changed files with 48 additions and 32 deletions.
src/axolotl/utils/config.py (5 additions, 0 deletions)
@@ -462,6 +462,11 @@ def validate_config(cfg):
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
)

if cfg.test_datasets and cfg.val_set_size:
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
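To make the intent of the new check concrete, here is a minimal sketch; the cfg object is a plain stand-in (SimpleNamespace), not axolotl's real config type, and check_eval_config is a hypothetical extraction of the added lines.

from types import SimpleNamespace

def check_eval_config(cfg):
    # Mirrors the added validation: an explicit test_datasets list and a
    # non-zero val_set_size are mutually exclusive ways to define an eval set.
    if cfg.test_datasets and cfg.val_set_size:
        raise ValueError(
            "non-zero val_set_size should not be used with test_datasets configuration"
        )

check_eval_config(SimpleNamespace(test_datasets=None, val_set_size=0.05))  # ok: carve eval out of train
check_eval_config(SimpleNamespace(test_datasets=[{"path": "my/eval"}], val_set_size=0))  # ok: explicit eval set
# Setting both at once raises the ValueError above.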
src/axolotl/utils/data.py (43 additions, 32 deletions)
@@ -3,7 +3,7 @@
import hashlib
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from datasets import (
@@ -60,9 +60,17 @@ def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
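For a concrete picture of the new dispatch in prepare_dataset, here is a hypothetical configuration written as Python literals standing in for the YAML datasets / test_datasets keys; the paths and types are made up.

# Stand-ins for the YAML config entries (illustrative only)
cfg_datasets = [
    {"path": "my-org/train-set", "type": "alpaca"},                  # feeds the split="train" pass
]
cfg_test_datasets = [
    {"path": "my-org/eval-set", "type": "alpaca", "split": "test"},  # feeds the split="test" pass
]
# When test_datasets is present, prepare_dataset calls load_prepare_datasets twice:
# the first pass keeps train_dataset and prompters, the second keeps only eval_dataset.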
@@ -98,8 +106,12 @@ def prepare_dataset(cfg, tokenizer):


def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
tokenizer,
cfg,
default_dataset_prepared_path,
split="train",
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5(
@@ -110,7 +122,7 @@ def load_tokenized_prepared_datasets(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg.datasets
for d in cfg_datasets
]
)
)
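Roughly how the prepared-dataset cache key is derived; this is a simplified sketch (the real hash also mixes in the tokenizer name, sequence length, and other fields outside this hunk), but it shows why the train list and the test list land in separate cache entries.

import hashlib

def dataset_cache_key(cfg_datasets):
    # One descriptor per dataset entry, sorted so ordering does not matter.
    descriptors = sorted(
        f"{d['path']}:{d['type']}:{d.get('shards')}:{d.get('conversation')}"
        for d in cfg_datasets
    )
    return hashlib.md5("|".join(descriptors).encode("utf-8")).hexdigest()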
@@ -133,7 +145,7 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset["train"]
dataset = dataset[split]
except Exception: # pylint: disable=broad-except # nosec
pass

@@ -164,8 +176,8 @@ def for_d_in_datasets(dataset_configs):
yield dataset

# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
try:
load_dataset(
@@ -318,34 +330,28 @@ def for_d_in_datasets(dataset_configs):
)
if not ds:
raise ValueError("unhandled dataset load")
# support for using a subset of the data
if config_dataset.shards:
if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=config_dataset.shards, index=0
)
else:
ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=0
)

d_base_type = d_prompt_style = None
d_type = config_dataset.type
if isinstance(d_type, str):
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
elif (
isinstance(ds, DatasetDict)
and config_dataset.train_on_split
and config_dataset.train_on_split in ds
):
ds = ds[config_dataset.train_on_split]

if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
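# Split precedence: an explicit `split:` on the dataset entry wins, then the
# split requested by the caller ("train" or "test"); a DatasetDict that offers
# neither raises rather than silently guessing.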

# support for using a subset of the data
if config_dataset.shards:
shards_idx = config_dataset.get("shards_idx", 0)
ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=shards_idx
)
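# Subsetting now happens after the split is selected, and shards_idx picks
# which shard to keep: shards: 4 with shards_idx: 1, for example, keeps the
# second quarter of the shuffled rows (previously only index 0 was used).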

dataset_wrapper, dataset_prompter = get_dataset_wrapper(
@@ -401,7 +407,9 @@ def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
@@ -422,7 +430,7 @@ def load_prepare_datasets(
+ str(max_packed_sequence_len)
+ seed
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg_datasets])
)
+ "|"
+ tokenizer_name
@@ -446,7 +454,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset["train"]
dataset = dataset[split]
except Exception: # pylint: disable=broad-except # nosec
pass

@@ -467,7 +475,7 @@ def load_prepare_datasets(
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
tokenizer, cfg, default_dataset_prepared_path, split=split
)

if cfg.seed:
@@ -521,7 +529,7 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)

if cfg.val_set_size:
if split == "train" and cfg.val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
@@ -554,6 +562,9 @@ def load_prepare_datasets(

train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
train_dataset = None
eval_dataset = dataset
else:
train_dataset = dataset
eval_dataset = None
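For intuition, the val_set_size branch above (taken only for the train split when no explicit test dataset is configured) is essentially a seeded train_test_split; a minimal, self-contained sketch with made-up data:

from datasets import Dataset

dataset = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})
split_ds = dataset.train_test_split(test_size=0.05, seed=42)  # test_size plays the role of val_set_size
train_dataset, eval_dataset = split_ds["train"], split_ds["test"]
print(len(train_dataset), len(eval_dataset))  # 95 5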
