support for explicit test_dataset definition for evals #786

Merged · 1 commit
Jan 23, 2024
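This change lets evals run against explicitly configured `test_datasets` instead of a hold-out carved from the training data with `val_set_size`; the two options become mutually exclusive. As a rough sketch (not taken from this PR), the two setups could look like the following Python dicts that an axolotl YAML config would parse into — dataset paths, types, and the exact key set are illustrative:

```python
# Illustrative only: dataset paths/types are placeholders, not from this PR.

# Option A: carve the eval split out of the training data.
cfg_val_split = {
    "datasets": [{"path": "tatsu-lab/alpaca", "type": "alpaca"}],
    "val_set_size": 0.05,
}

# Option B (this PR): evaluate on explicitly listed datasets.
# val_set_size must then be 0/unset, otherwise validate_config raises.
cfg_test_datasets = {
    "datasets": [{"path": "tatsu-lab/alpaca", "type": "alpaca"}],
    "test_datasets": [
        {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca", "split": "test"}
    ],
    "val_set_size": 0,
}
```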
5 changes: 5 additions & 0 deletions src/axolotl/utils/config.py
@@ -519,6 +519,11 @@ def validate_config(cfg):
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
)

if cfg.test_datasets and cfg.val_set_size:
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
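A minimal sketch of the new check, assuming axolotl's `DictDefault` config wrapper is importable from `axolotl.utils.dict` (an assumption) and that the remaining validation branches pass with empty defaults:

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault  # assumed import path

# Conflicting configuration: explicit test_datasets plus a non-zero val_set_size.
cfg = DictDefault(
    {
        "test_datasets": [{"path": "mhenrichsen/alpaca_2k_test", "split": "test"}],
        "val_set_size": 0.05,
    }
)

try:
    validate_config(cfg)
except ValueError as err:
    # "non-zero val_set_size should not be used with test_datasets configuration"
    print(err)
```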
68 changes: 39 additions & 29 deletions src/axolotl/utils/data.py
@@ -4,7 +4,7 @@
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from datasets import (
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
path = cfg.pretraining_dataset
name = None
@@ -108,8 +116,12 @@ def prepare_dataset(cfg, tokenizer):


def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
tokenizer,
cfg,
default_dataset_prepared_path,
split="train",
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5(
@@ -126,7 +138,7 @@ def load_tokenized_prepared_datasets(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg.datasets
for d in cfg_datasets
]
)
)
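Since the prepared-dataset cache key is now derived from `cfg_datasets` (the dataset list for the requested split), train and test preparations hash to different entries. A rough illustration of that idea with plain `hashlib` (the real key also mixes in the tokenizer name and other settings not shown in this hunk):

```python
from hashlib import md5

# Illustrative "path:type:shards:conversation" strings; real values come from
# cfg.datasets (train) and cfg.test_datasets (test).
train_cfgs = ["tatsu-lab/alpaca:alpaca:None:None"]
test_cfgs = ["mhenrichsen/alpaca_2k_test:alpaca:None:None"]

train_hash = md5("|".join(sorted(train_cfgs)).encode("utf-8")).hexdigest()
test_hash = md5("|".join(sorted(test_cfgs)).encode("utf-8")).hexdigest()
assert train_hash != test_hash  # separate prepared-dataset cache entries per split
```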
@@ -149,7 +161,7 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset["train"]
dataset = dataset[split]
except Exception: # pylint: disable=broad-except # nosec
pass

@@ -188,8 +200,8 @@ def for_d_in_datasets(dataset_configs):
yield dataset

# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
try:
load_dataset(
@@ -342,34 +354,28 @@ def for_d_in_datasets(dataset_configs):
)
if not ds:
raise ValueError("unhandled dataset load")
# support for using a subset of the data
if config_dataset.shards:
if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=config_dataset.shards, index=0
)
else:
ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=0
)

d_base_type = d_prompt_style = None
d_type = config_dataset.type
if isinstance(d_type, str):
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
elif (
isinstance(ds, DatasetDict)
and config_dataset.train_on_split
and config_dataset.train_on_split in ds
):
ds = ds[config_dataset.train_on_split]

if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)

# support for using a subset of the data
if config_dataset.shards:
shards_idx = config_dataset.get("shards_idx", 0)
ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=shards_idx
)
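The shard subsetting now runs after the split is selected and honors a `shards_idx` offset instead of always taking shard 0. A small standalone illustration with the `datasets` library (toy data; shard counts are made up):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

# Same pattern as ds.shuffle(seed=seed).shard(num_shards=shards, index=shards_idx):
# keep only the third of ten shards after a seeded shuffle.
subset = ds.shuffle(seed=42).shard(num_shards=10, index=2)
print(len(subset))  # 10 rows out of 100
```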

dataset_wrapper, dataset_prompter = get_dataset_wrapper(
@@ -428,6 +434,7 @@ def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
@@ -442,7 +449,7 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)

if cfg.val_set_size:
if split == "train" and cfg.val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
@@ -475,6 +482,9 @@

train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
train_dataset = None
eval_dataset = dataset
else:
train_dataset = dataset
eval_dataset = None
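So the train split still honors a non-zero `val_set_size` by splitting the prepared data into train/eval, split="test" returns the whole loaded dataset as `eval_dataset`, and anything else is treated purely as training data. A hedged sketch of what the `val_set_size` branch produces, using the `datasets` API (the fingerprint pinning done above for rank-0 caching is omitted):

```python
from datasets import Dataset

dataset = Dataset.from_dict({"input_ids": [[1, 2, 3]] * 100})

# Roughly what val_set_size = 0.05 yields for the train split in this sketch.
split_ds = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset, eval_dataset = split_ds["train"], split_ds["test"]
print(len(train_dataset), len(eval_dataset))  # 95 5
```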