
Commit

streaming multipack for pretraining dataset (#959)
* [Feat] streaming multipack

* WIP make continued pretraining work w multipack

* fix up hardcoding, lint

* fix dict check

* update test for updated pretraining multipack code

* fix hardcoded data collator fix for multipack pretraining

* fix the collator to be the max length for multipack pretraining

* don't bother with latest tag for test

* cleanup docker build/test

---------

Co-authored-by: [email protected] <jinwonkim>
Co-authored-by: Wing Lian <[email protected]>
jinwonkim93 and winglian authored Jan 6, 2024
1 parent eb4c994 commit 553c80f
Showing 7 changed files with 282 additions and 12 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/tests-docker.yml
@@ -20,7 +20,6 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -37,7 +36,7 @@ jobs:
          images: winglian/axolotl
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Build and export to Docker
      - name: Build Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
@@ -49,8 +48,7 @@ jobs:
          file: ./docker/Dockerfile
          tags: |
            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
          labels: ${{ steps.metadata.outputs.labels }}
      - name: Unit Tests
      - name: Unit Tests w docker image
        run: |
          docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
58 changes: 58 additions & 0 deletions examples/tiny-llama/pretrain.yml
@@ -0,0 +1,58 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

max_steps: 200
pretraining_dataset:
  path: c4
  name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out

sequence_len: 2048
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
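
As a side note for readers of this commit, the example config above streams the c4/en dataset and packs samples to sequence_len while training TinyLlama for max_steps: 200. Below is a minimal, hedged launch sketch; it assumes axolotl's fire-based do_cli entry point in axolotl.cli.train (the README-style invocation is accelerate launch -m axolotl.cli.train <config>) and is illustrative rather than part of the diff:

# Illustrative launch sketch -- not part of this commit. Assumes axolotl.cli.train
# exposes a fire-based do_cli(config=...) entry point.
from axolotl.cli.train import do_cli

if __name__ == "__main__":
    # streams c4/en, packs to sequence_len=2048, trains TinyLlama for 200 steps
    do_cli(config="examples/tiny-llama/pretrain.yml")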
18 changes: 14 additions & 4 deletions src/axolotl/core/trainer_builder.py
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
    pretraining: bool = field(
        default=False,
        metadata={
            "help": "Indicates to trainer whether we are doing continued pretraining."
        },
    )
    sample_packing: bool = field(
        default=False,
        metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ def create_scheduler(
        return self.lr_scheduler

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing:
        if self.args.sample_packing and not self.args.pretraining:
            return MultipackBatchSampler(
                RandomSampler(self.train_dataset),
                self.args.train_batch_size,
@@ -193,7 +199,7 @@ def _get_eval_sampler(
        return super()._get_eval_sampler(eval_dataset)

    def get_train_dataloader(self) -> DataLoader:
        if self.args.sample_packing:
        if self.args.sample_packing and not self.args.pretraining:
            train_dataset = self.train_dataset
            train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
@@ -768,6 +774,7 @@ def build(self, total_num_steps):
            training_arguments_kwargs
        )
        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

        if self.cfg.neftune_noise_alpha is not None:
            training_arguments_kwargs[
@@ -808,7 +815,7 @@ def build(self, total_num_steps):
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            args=training_args,
            data_collator=self.build_collator(**data_collator_kwargs),
            data_collator=self.build_collator(training_args, **data_collator_kwargs),
            bench_data_collator=transformers.DataCollatorForSeq2Seq(
                self.tokenizer,
                return_tensors="pt",
@@ -829,7 +836,10 @@ def build(self, total_num_steps):

        return trainer

    def build_collator(self, **kwargs):
    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
        if training_args.pretraining:
            return None

        if self.cfg.model_config_type == "mamba":
            return MambaDataCollator(tokenizer=self.tokenizer)

21 changes: 21 additions & 0 deletions src/axolotl/utils/collators.py
@@ -178,3 +178,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
"input_ids": input_ids,
"labels": labels,
}


@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to the using the BatchSampler
"""

def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features.keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [(1) * np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
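
To make the behaviour of the new collator concrete, here is a hedged usage sketch (not part of the commit): a columnar dict holding two already-packed sub-sequences goes in, and a single concatenated row padded up to a multiple comes out. The tokenizer checkpoint, the toy token ids, and the small pad_to_multiple_of value are illustrative assumptions; in the commit the multiple is sequence_len * micro_batch_size:

# Hedged sketch for PretrainingBatchSamplerDataCollatorForSeq2Seq -- illustrative only.
from transformers import AutoTokenizer

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.pad_token = tokenizer.eos_token  # make sure padding is possible

collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
    tokenizer,
    return_tensors="pt",
    padding=True,
    pad_to_multiple_of=16,  # the commit uses max_tokens * cfg.micro_batch_size here
)

features = {
    "input_ids": [[1, 306, 763, 2], [1, 7826, 2]],  # two packed sub-sequences (toy ids)
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1]],
    "labels": [[1, 306, 763, 2], [1, 7826, 2]],
    "length": [4, 3],  # dropped by the collator
}

batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([1, 16]): one concatenated, padded row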
99 changes: 95 additions & 4 deletions src/axolotl/utils/data.py
@@ -2,6 +2,7 @@
import functools
import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

@@ -14,6 +15,7 @@
    load_from_disk,
)
from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@
    SummarizeTLDRPrompter,
    UnsupportedPrompter,
)
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
    process_pretraining_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
            )
    else:
        path = cfg.pretraining_dataset
        name = None
        if isinstance(cfg.pretraining_dataset, dict):
            path = cfg.pretraining_dataset["path"]
            name = cfg.pretraining_dataset["name"]

        train_dataset = load_pretraining_dataset(
            cfg.pretraining_dataset,
            path,
            tokenizer,
            cfg,
            name=name,
            max_tokens=cfg.sequence_len,
            seed=cfg.seed or 42,
        )
@@ -806,9 +819,27 @@ def encode_pretraining(
    return ret


def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
    dataset = load_dataset(path, streaming=True, split="train")
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
    if cfg.sample_packing:
        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
        )
        encode = functools.partial(
            encode_packed_pretraining,
            tokenizer,
            collate_fn,
            max_seq_length=max_tokens,
            batch_size=cfg.micro_batch_size,
        )
        # set this to 1 so downstream data_loader doesn't try to increase the batch again
        cfg.micro_batch_size = 1
    else:
        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

    dataset = load_dataset(path, streaming=True, split="train", name=name)
    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
    dataset = dataset.map(
        encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
        remove_columns=dataset.features.keys(),
    )
    return dataset


def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase,
    collate_fn,
    examples: List[str],
    max_seq_length: int = 2048,
    batch_size: int = 4,
) -> Dict[str, List]:
    # pylint: disable=duplicate-code
    # tokenize all the examples
    # rows get split with stride (overlap)
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_seq_length - 1,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )

    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
    attention_mask = [seq + [1] for seq in res["attention_mask"]]

    tokenized_examples = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }

    train_dataset = Dataset.from_dict(tokenized_examples)
    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset, max_seq_length
    )

    sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
        batch_size=batch_size,
        drop_last=True,
        batch_max_len=batch_size * max_seq_length,
        lengths=(
            train_dataset.data.column("position_ids")
            .to_pandas()
            .apply(lambda x: x[-1] + 1)
            .values
        )
    )

    chunked_data = defaultdict(list)

    for data in sampler:
        features = train_dataset[data]
        features["labels"] = features["input_ids"].copy()
        collated_features = collate_fn(features)

        for feature in features.keys():
            if feature == "length":
                continue
            chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data
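
End to end, the streaming-multipack path above can be exercised with a few lines. This is a hedged sketch, not part of the commit: the config values mirror examples/tiny-llama/pretrain.yml, and pulling the first element is slow because a whole mapped batch of c4 documents is tokenized, packed, and pre-collated on the fly:

# Hedged end-to-end sketch -- illustrative only, mirrors examples/tiny-llama/pretrain.yml.
from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
cfg = DictDefault({"sample_packing": True, "micro_batch_size": 2, "sequence_len": 2048})

dataset = load_pretraining_dataset(
    "c4",
    tokenizer,
    cfg,
    name="en",
    max_tokens=cfg.sequence_len,
    seed=42,
)
# cfg.micro_batch_size has been reset to 1: each streamed example is already one
# packed, collated row of (original) micro_batch_size * sequence_len token ids.
row = next(iter(dataset))
print(len(row["input_ids"]))  # expected 2 * 2048 = 4096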
10 changes: 10 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
    return train_dataset, eval_dataset


def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
    drop_long = partial(drop_long_seq, sequence_len=sequence_len)

    train_dataset = train_dataset.filter(drop_long)
    train_dataset = train_dataset.map(
        add_position_ids,
    )
    return train_dataset


def calculate_total_num_steps(cfg, train_dataset, update=True):
    if not cfg.total_num_tokens:
        total_num_tokens = np.sum(
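
Finally, a hedged illustration of the new packing helper (not part of the commit). It assumes drop_long_seq filters on len(input_ids) and add_position_ids attaches position_ids (and a length column); both helpers already live in axolotl.utils.trainer. An over-long row is dropped, and the survivor gains consecutive position ids that the multipack sampler later uses to recover sequence lengths:

# Hedged sketch -- illustrative toy rows, assumes the helper semantics described above.
from datasets import Dataset

from axolotl.utils.trainer import process_pretraining_datasets_for_packing

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 5, 7, 2], list(range(3000))],
        "attention_mask": [[1, 1, 1, 1], [1] * 3000],
    }
)
ds = process_pretraining_datasets_for_packing(ds, sequence_len=2048)
print(ds.num_rows)            # 1: the 3000-token row exceeds sequence_len and is dropped
print(ds[0]["position_ids"])  # [0, 1, 2, 3]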
