From 5294653a2d353066600cbc66bb06f7c63c87147b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 27 Apr 2024 12:28:20 -0400
Subject: [PATCH] PoSE context length ext (#1567)

* PoSE wip
* fixes for pose splitting
* set pose context len so we can pick that up separately from the usable training context len
* support min sample len and define num chunks
* fix chunk splitting
* support for curriculum/ordered learning with pose
* fix sequence len sort
* add curriculum_sampling to pydantic
---
 src/axolotl/core/trainer_builder.py           |   7 ++
 .../config/models/input/v0_4_1/__init__.py    |   8 ++
 src/axolotl/utils/trainer.py                  | 108 +++++++++++++++++-
 3 files changed, 118 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 6bddb95740..09651bdc9b 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "path under the model to access the layers"},
     )
+    curriculum_sampling: Optional[bool] = field(
+        default=None,
+        metadata={"help": "whether to use sequential sampling for curriculum learning"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -347,6 +351,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
                 lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
+        if self.args.curriculum_sampling:
+            return SequentialSampler(self.train_dataset)
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
@@ -1193,6 +1199,7 @@ def build(self, total_num_steps):
             False if self.cfg.ddp else None
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         report_to = None
         if self.cfg.use_wandb:
             report_to = "wandb"
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index d99155ac25..e27a8ddd52 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -503,9 +503,17 @@ class Config:
     unfrozen_parameters: Optional[List[str]] = None
 
     sequence_len: int = Field(default=512)
+    min_sample_len: Optional[int] = None
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
+    curriculum_sampling: Optional[bool] = None
+
+    # for PoSE context length extension
+    use_pose: Optional[bool] = None
+    pose_split_on_token_ids: Optional[List[int]] = None
+    pose_max_context_len: Optional[int] = None
+    pose_num_chunks: Optional[int] = None
 
     pretrain_multipack_buffer_size: Optional[int] = 10_000
     pretrain_multipack_attn: Optional[bool] = Field(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 808fbb59f5..2e3728cc8a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,9 +1,10 @@
 """Module containing the Trainer class and related functions"""
 import math
 import os
+import random
 from contextlib import contextmanager
 from functools import partial
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -98,17 +99,89 @@ def add_position_ids(sample):
     return sample
 
 
+def add_pose_position_ids(
+    sample,
+    max_context_len=32768,
+    split_on_token_ids: Optional[List[int]] = None,
+    chunks: int = 2,
+):
+    """
+    use the PoSE technique to extend the context length by randomly skipping
+    positions in the context. We only want to skip right before tokens in
+    the split_on_token_ids list. We should attempt to randomly distribute
+    the skips, but we don't need the final position_ids to be the full
+    context_len. There may be multiple turns in the context, so we want to
+    make sure we take into account the maximum possible number of skips
+    remaining in each sample.
+    """
+
+    input_ids = sample["input_ids"]
+    sample_len = len(input_ids)
+    max_skips = max_context_len - sample_len
+
+    if split_on_token_ids is None:
+        split_on_token_ids = []
+
+    if split_on_token_ids:
+        split_indices = [
+            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
+        ]
+    else:
+        chunk_len = sample_len // chunks
+        split_indices = [i * chunk_len for i in range(1, chunks)]
+    split_indices.append(len(input_ids))  # make sure we go to the end of the sample
+    if split_indices[0] < 2:
+        # drop the first split index if it's too close to the beginning
+        split_indices = split_indices[1:]
+
+    position_ids = []
+    prev_index = 0
+    total_skips = 0
+
+    for split_index in split_indices:
+        num_skips = (
+            random.randint(0, max_skips)  # nosec B311
+            if prev_index != 0 and max_skips
+            else 0
+        )
+        max_skips -= num_skips
+        total_skips += num_skips
+
+        segment_position_ids = list(
+            range(prev_index + total_skips, split_index + total_skips)
+        )
+
+        position_ids.extend(segment_position_ids)
+        prev_index = split_index
+
+    sample["sequence_len"] = position_ids[-1]
+    position_ids = torch.tensor(position_ids)
+
+    sample["position_ids"] = position_ids
+    sample["length"] = len(position_ids)
+    assert len(position_ids) == len(input_ids)
+
+    return sample
+
+
 def add_length(sample):
     sample["length"] = len(sample["input_ids"])
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048):
-    return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
     with zero_first(is_main_process()):
         if cfg.is_preprocess:
             min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
                 desc="Group By Length",
             )
 
-        if cfg.sample_packing:
+        if cfg.use_pose:
+            pose_kwargs = {}
+            if cfg.pose_num_chunks is not None:
+                pose_kwargs["chunks"] = cfg.pose_num_chunks
+            pose_fn = partial(
+                add_pose_position_ids,
+                max_context_len=cfg.pose_max_context_len,
+                split_on_token_ids=cfg.pose_split_on_token_ids,
+                **pose_kwargs,
+            )
+            train_dataset = train_dataset.map(
+                pose_fn,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
+                desc="Add position_id column (PoSE)",
+            )
+            train_dataset = train_dataset.sort("sequence_len")
+            if cfg.eval_sample_packing is not False:
+                if eval_dataset:
+                    eval_dataset = eval_dataset.map(
+                        pose_fn,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
+                        desc="Add position_id column (PoSE)",
+                    )
+        elif cfg.sample_packing:
             train_dataset = train_dataset.map(
                 add_position_ids,
                 num_proc=cfg.dataset_processes,
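
For context, here is a minimal usage sketch of the new add_pose_position_ids helper. It is illustrative only and not part of the patch; it assumes the patch above is applied and torch is installed, and the concrete numbers in the comments are just one possible outcome because the skip offsets are drawn at random. In a training config, the same path is enabled by the new use_pose, pose_max_context_len, pose_num_chunks, and pose_split_on_token_ids keys, with min_sample_len and curriculum_sampling as companions.

# Usage sketch (not part of the patch): shows what add_pose_position_ids
# produces for a toy sample once this change is applied.
from axolotl.utils.trainer import add_pose_position_ids

sample = {"input_ids": list(range(10))}  # pretend these are 10 token ids

# Stretch the 10-token sample across a simulated 4096-token context window,
# splitting it into the default two chunks.
out = add_pose_position_ids(sample, max_context_len=4096, chunks=2)

print(out["position_ids"])
# e.g. tensor([   0,    1,    2,    3,    4, 2713, 2714, 2715, 2716, 2717])
# The first chunk keeps positions 0..4; the second chunk is shifted by a random
# skip, so the model trains on position ids far beyond the sample's real length.
print(out["sequence_len"])  # largest position id reached, e.g. 2717
print(out["length"])        # still 10: one position id per input token

Because each sample records its stretched extent in sequence_len, the train_dataset.sort("sequence_len") call plus curriculum_sampling: true feeds samples to the trainer in order of increasing effective context length.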