PoSE context length ext (#1567)
* PoSE wip

* fixes for pose splitting

* set pose context len so we can pick that up separately from the usable training context len

* support min sample len and define num chunks

* fix chunk splitting

* support for curriculum/ordered learning with pose

* fix sequence len sort

* add curriculum_sampling to pydantic
winglian authored Apr 27, 2024
1 parent 98c25e1 commit 5294653
Showing 3 changed files with 118 additions and 5 deletions.
7 changes: 7 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)


class AxolotlTrainer(Trainer):
@@ -347,6 +351,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()

def _get_eval_sampler(
@@ -1193,6 +1199,7 @@ def build(self, total_num_steps):
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = None
if self.cfg.use_wandb:
report_to = "wandb"
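
As a point of reference, a minimal sketch (not part of this commit) of what the curriculum_sampling flag changes in _get_train_sampler above: instead of the Trainer's usual shuffled sampler, a SequentialSampler walks the dataset in index order, so once the dataset has been sorted by length (see the sort("sequence_len") call in trainer.py below) training proceeds shortest-first.

# Illustrative only: SequentialSampler yields indices 0..N-1 in order, so a
# length-sorted dataset is consumed shortest-first (the curriculum).
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

sorted_dataset = TensorDataset(torch.arange(10))  # stand-in for a dataset already sorted by length
loader = DataLoader(sorted_dataset, batch_size=2, sampler=SequentialSampler(sorted_dataset))
print([batch[0].tolist() for batch in loader])  # [[0, 1], [2, 3], ...] -- fixed order, no shuffling
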
8 changes: 8 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -503,9 +503,17 @@ class Config:
unfrozen_parameters: Optional[List[str]] = None

sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None

# for PoSE context length extension
use_pose: Optional[bool] = None
pose_split_on_token_ids: Optional[List[int]] = None
pose_max_context_len: Optional[int] = None
pose_num_chunks: Optional[int] = None

pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
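
For reference, a hypothetical set of values for the fields added above (illustrative numbers only; in axolotl these keys normally live in the YAML config, shown here as a plain Python dict):

# Hypothetical example values for the new fields (not taken from the commit):
pose_example_cfg = {
    "sequence_len": 8192,             # usable training context length
    "min_sample_len": 64,             # drop samples shorter than this many tokens
    "curriculum_sampling": True,      # sequential (length-ordered) train sampler
    "use_pose": True,                 # enable PoSE position-id skipping
    "pose_max_context_len": 65536,    # context length the skipped position ids may span
    "pose_num_chunks": 2,             # segments per sample when no split token ids are given
    "pose_split_on_token_ids": None,  # or a list of token ids to split right before
}
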
108 changes: 103 additions & 5 deletions src/axolotl/utils/trainer.py
@@ -1,9 +1,10 @@
"""Module containing the Trainer class and related functions"""
import math
import os
import random
from contextlib import contextmanager
from functools import partial
-from typing import List
+from typing import List, Optional

import numpy as np
import torch
@@ -98,17 +99,89 @@ def add_position_ids(sample):
return sample


def add_pose_position_ids(
sample,
max_context_len=32768,
split_on_token_ids: Optional[List[int]] = None,
chunks: int = 2,
):
"""
use the PoSE technique to extend the context length by randomly skipping
positions in the context. We only want to skip right before tokens in
the split_on_token_ids list. We should attempt to randomly distribute
the skips, but we don't need the final position_ids to be the full
context_len. There may be multiple turns in the context, so we want to
make sure we take into account the maximum possible number of skips
remaining in each sample.
"""

input_ids = sample["input_ids"]
sample_len = len(input_ids)
max_skips = max_context_len - sample_len

if split_on_token_ids is None:
split_on_token_ids = []

if split_on_token_ids:
split_indices = [
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
]
else:
chunk_len = sample_len // chunks
split_indices = [i * chunk_len for i in range(1, chunks)]
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
if split_indices[0] < 2:
# drop the first split index if it's too close to the beginning
split_indices = split_indices[1:]

position_ids = []
prev_index = 0
total_skips = 0

for split_index in split_indices:
num_skips = (
random.randint(0, max_skips) # nosec B311
if prev_index != 0 and max_skips
else 0
)
max_skips -= num_skips
total_skips += num_skips

segment_position_ids = list(
range(prev_index + total_skips, split_index + total_skips)
)

position_ids.extend(segment_position_ids)
prev_index = split_index

sample["sequence_len"] = position_ids[-1]
position_ids = torch.tensor(position_ids)

sample["position_ids"] = position_ids
sample["length"] = len(position_ids)
assert len(position_ids) == len(input_ids)

return sample
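
To make the skipping concrete, an illustrative run of the function above on toy token ids (the skip sizes are drawn at random, so the exact output varies between runs):

# Illustrative usage of add_pose_position_ids (toy data, not part of the diff):
from axolotl.utils.trainer import add_pose_position_ids

sample = {"input_ids": list(range(100, 108))}  # 8 tokens
out = add_pose_position_ids(sample, max_context_len=16, chunks=2)
# With chunks=2 the sample splits at index 4: the first segment keeps positions
# [0, 1, 2, 3] and the second segment is shifted by a random skip of at most
# max_context_len - len(input_ids) = 8, e.g. tensor([0, 1, 2, 3, 9, 10, 11, 12]).
print(out["position_ids"])
print(out["sequence_len"])  # the final position id, e.g. 12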


def add_length(sample):
sample["length"] = len(sample["input_ids"])
return sample


-def drop_long_seq(sample, sequence_len=2048):
-    return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )


def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Group By Length",
)

-        if cfg.sample_packing:
+        if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
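
Putting the PoSE branch above together, a rough illustrative sketch (toy dataset, not from the commit) of what the mapping plus the length sort produce before the train sampler reads the data:

# Illustrative end-to-end sketch of the PoSE mapping + sort on a toy dataset:
from functools import partial
from datasets import Dataset
from axolotl.utils.trainer import add_pose_position_ids

toy = Dataset.from_dict({"input_ids": [list(range(6)), list(range(12)), list(range(9))]})
pose_fn = partial(add_pose_position_ids, max_context_len=32, chunks=2)
toy = toy.map(pose_fn).sort("sequence_len")  # ordered shortest-first for curriculum_sampling
print(toy["sequence_len"])  # final position ids in ascending order (values depend on the random skips)
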
