Pretrain transforms #1261

Merged · 9 commits · Feb 6, 2024
1 change: 1 addition & 0 deletions examples/tiny-llama/pretrain.yml
@@ -12,6 +12,7 @@ max_steps: 200
pretraining_dataset:
path: c4
name: en
type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out
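
A side note on the new `type` key (not part of the diff): prepare_dataset reads it from the first pretraining_dataset entry and falls back to "pretrain" when the key resolves to nothing. A rough Python sketch of that resolution, with a made-up config dict mirroring the YAML above:

# Hypothetical config entry; the fallback mirrors
# `cfg.pretraining_dataset[0]["type"] or "pretrain"` in src/axolotl/utils/data.py.
pretraining_dataset = [{"path": "c4", "name": "en", "type": "pretrain"}]
ds_type = pretraining_dataset[0].get("type") or "pretrain"
print(ds_type)  # -> "pretrain" (also the result when "type" is omitted)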
2 changes: 1 addition & 1 deletion src/axolotl/datasets.py
@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
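
Context for the IterableDataset -> Dataset hint change: with this PR the wrapper is handed a materialized Dataset built from each streamed batch via Dataset.from_dict (see encode_packed_pretraining in src/axolotl/utils/data.py below), rather than the streaming dataset itself. A minimal sketch of that conversion, with placeholder texts:

from datasets import Dataset

# A streamed batch arrives as a dict of columns; these example rows are placeholders.
examples = {"text": ["first pretraining document", "second pretraining document"]}
ds = Dataset.from_dict(examples)  # materialized, indexable Dataset
print(ds[0]["text"])              # -> "first pretraining document"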
58 changes: 58 additions & 0 deletions src/axolotl/prompt_strategies/pretrain.py
@@ -0,0 +1,58 @@
"""pretraining prompt strategies"""
from typing import Generator

from transformers import BatchEncoding

from axolotl.prompt_tokenizers import PromptTokenizingStrategy


class PretrainTokenizer:
"""basic tokenization class for pretraining"""

def build_prompt(self, prompt) -> Generator[str, None, None]:
yield prompt


class PretrainTokenizationStrategy(PromptTokenizingStrategy):
"""handles tokenization for pretraining with strides"""

@property
def supports_batched(self):
return True

def __init__(self, *args, max_length=None, **kwargs):
super().__init__(*args, **kwargs)
if max_length:
self.max_length = max_length

def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
res = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
res["input_ids"] = [
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]

return res

def tokenize_prompt(self, prompt):
return self._tokenize(prompt["text"])


def load(tokenizer, cfg):
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
return strat
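
A minimal sketch of what _tokenize above produces, assuming a fast tokenizer (the multi-window return_overflowing_tokens behavior relies on the fast backend); the checkpoint name and lengths here are placeholders, not values from this PR:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "a long pretraining document " * 2000

enc = tokenizer(
    text,
    truncation=True,
    max_length=512 - 1,              # leave room for the EOS appended below
    add_special_tokens=True,
    return_overflowing_tokens=True,  # keep every overflow window, not just the first
    stride=256,                      # consecutive windows overlap by 256 tokens
)

# Mirror the strategy: terminate each window with EOS and extend the attention mask.
input_ids = [seq + [tokenizer.eos_token_id] for seq in enc["input_ids"]]
attention_mask = [seq + [1] for seq in enc["attention_mask"]]
print(len(input_ids), len(input_ids[0]))  # several windows; the first is 512 tokens after the EOS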
85 changes: 49 additions & 36 deletions src/axolotl/utils/data.py
@@ -4,7 +4,7 @@
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import yaml
@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]

train_dataset = load_pretraining_dataset(
path,
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
tokenizer,
cfg,
name=name,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)

train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name),
tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -383,9 +392,9 @@ def for_d_in_datasets(dataset_configs):

dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
@@ -496,7 +505,12 @@ def load_prepare_datasets(


def get_dataset_wrapper(
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
config_dataset,
tokenizer,
cfg,
d_base_type,
dataset,
d_prompt_style=None,
):
dataset_wrapper = None
dataset_prompter = None
@@ -507,7 +521,8 @@ def get_dataset_wrapper(
}

if (
"input_ids" in dataset.features
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
@@ -765,69 +780,60 @@ def encode_pretraining(
return ret


def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
pad_to_multiple_of=max_tokens * batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size,
batch_size=batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
batch_size=buffer_size,
# input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
desc="Encoding Pretraining",
)
return dataset


def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
examples: List[str],
ds_wrapper: Callable,
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)

input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]

tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)
@@ -845,7 +851,14 @@ def encode_packed_pretraining(
for batch in sampler:
for data in batch:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)

for feature in features.keys():
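
One more note on the sample_packing branch above: each row emitted by dataset.map is already a fully collated pack, which is why cfg.micro_batch_size is reset to 1 so the downstream data loader does not batch it again. Back-of-the-envelope shapes, using the same numbers as the test below (illustrative values only):

sequence_len = 2048           # cfg.sequence_len, i.e. max_tokens
micro_batch_size = 2          # original cfg.micro_batch_size, before the reset to 1
tokens_per_packed_row = sequence_len * micro_batch_size
print(tokens_per_packed_row)  # 4096, matching the [1, original_bsz * cfg.sequence_len]
                              # shapes asserted in tests/test_packed_pretraining.py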
61 changes: 36 additions & 25 deletions tests/test_packed_pretraining.py
@@ -1,14 +1,14 @@
"""Module for testing streaming dataset sequence packing"""
import functools
import unittest
from functools import partial

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault


class TestPretrainingPacking(unittest.TestCase):
@@ -20,8 +20,6 @@ def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2

def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
@@ -31,30 +29,43 @@ def test_packing_stream_dataset(self):
streaming=True,
)["train"]

collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=self.max_seq_length,
cfg = DictDefault(
{
"pretraining_dataset": [
{
"path": "c4",
"name": "en",
"type": "pretrain",
}
],
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"micro_batch_size": 2,
}
)

encode = partial(
encode_packed_pretraining,
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
self.tokenizer,
collate_fn,
max_seq_length=self.max_seq_length,
batch_size=self.batch_size,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)

dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=dataset.features.keys(),
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
dataset,
self.tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)

trainer_loader = DataLoader(
dataset,
train_dataset,
batch_size=1,
collate_fn=None,
drop_last=True,
@@ -64,16 +75,16 @@
if idx > 10:
break
assert data["input_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
[1, original_bsz * cfg.sequence_len]
)
assert data["position_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
[1, original_bsz * cfg.sequence_len]
)
assert data["labels"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
[1, original_bsz * cfg.sequence_len]
)
assert data["attention_mask"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
[1, original_bsz * cfg.sequence_len]
)
idx += 1
