diff --git a/gitbook/README.md b/gitbook/README.md
deleted file mode 100644
index 642bde22ae..0000000000
--- a/gitbook/README.md
+++ /dev/null
@@ -1 +0,0 @@
-# Page
diff --git a/gitbook/SUMMARY.md b/gitbook/SUMMARY.md
deleted file mode 100644
index acfb2d0100..0000000000
--- a/gitbook/SUMMARY.md
+++ /dev/null
@@ -1,4 +0,0 @@
-# Table of contents
-
-* [Page](README.md)
-* [Small dev details](small-dev-details.md)
diff --git a/gitbook/small-dev-details.md b/gitbook/small-dev-details.md
deleted file mode 100644
index f5eddf405b..0000000000
--- a/gitbook/small-dev-details.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Small dev details
-
-/
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 55a1764fcd..2269880375 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -6,7 +6,6 @@
 import importlib
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field
@@ -18,9 +17,9 @@
 import transformers
 from datasets import Dataset
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import SequentialDistributedSampler
+from transformers.trainer_utils import seed_worker
 
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -31,8 +30,9 @@
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 try:
@@ -102,6 +102,10 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None,
+        metadata={"help": "prefetch_factor argument to the dataloader"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -145,46 +149,69 @@ def create_scheduler(
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
-            return DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                seed=self.args.seed,
+        if self.args.sample_packing:
+            return MultipackBatchSampler(
+                RandomSampler(self.train_dataset),
+                self.args.train_batch_size,
+                drop_last=True,
+                batch_max_len=self._train_batch_size * self.args.max_seq_length,
+                lengths=(
+                    self.train_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
+                ),
+                packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
         self, eval_dataset: Dataset
     ) -> Optional[torch.utils.data.Sampler]:
-        if (
-            self.args.world_size > 1
-            and self.args.sample_packing
-            and self.args.eval_sample_packing is not False
-        ):
-            return SequentialDistributedSampler(
-                eval_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                batch_size=self.args.per_device_eval_batch_size,
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            return MultipackBatchSampler(
+                SequentialSampler(eval_dataset),
+                self.args.per_device_eval_batch_size,
+                drop_last=True,
+                batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
+                lengths=(
+                    eval_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
+                ),
+                packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_eval_sampler(eval_dataset)
 
-    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
+    def get_train_dataloader(self) -> DataLoader:
         if self.args.sample_packing:
-            train_sampler = self._get_train_sampler()
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    self.train_dataset,
-                    batch_size=self._train_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=train_sampler,
-                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
-                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
-                )
+            train_dataset = self.train_dataset
+            train_dataset = train_dataset.remove_columns(["length"])
+            data_collator = self.data_collator
+            dataloader_params = {
+                "batch_size": self._train_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+            if self.args.dataloader_prefetch_factor:
+                dataloader_params[
+                    "prefetch_factor"
+                ] = self.args.dataloader_prefetch_factor
+
+            sampler = self._get_train_sampler()
+            if isinstance(sampler, BatchSampler):
+                dataloader_params["batch_sampler"] = sampler
+                del dataloader_params["batch_size"]
+            else:
+                dataloader_params["sampler"] = sampler
+                dataloader_params["drop_last"] = self.args.dataloader_drop_last
+                dataloader_params["worker_init_fn"] = seed_worker
+
+            self.accelerator.even_batches = False
+            return self.accelerator.prepare_data_loader(
+                DataLoader(train_dataset, **dataloader_params)
             )
         return super().get_train_dataloader()
 
@@ -197,18 +224,29 @@ def get_eval_dataloader(
             )
 
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    eval_dataset,
-                    batch_size=self.args.eval_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=eval_sampler,
-                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
-                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
-                )
+            eval_dataset = eval_dataset.remove_columns(["length"])
+            data_collator = self.data_collator
+            dataloader_params = {
+                "batch_size": self.args.eval_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+            if self.args.dataloader_prefetch_factor:
+                dataloader_params[
+                    "prefetch_factor"
+                ] = self.args.dataloader_prefetch_factor
+
+            if isinstance(eval_sampler, BatchSampler):
+                dataloader_params["batch_sampler"] = eval_sampler
+                del dataloader_params["batch_size"]
+            else:
+                dataloader_params["sampler"] = eval_sampler
+                dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+            self.accelerator.even_batches = False
+            return self.accelerator.prepare_data_loader(
+                DataLoader(eval_dataset, **dataloader_params)
             )
         return super().get_eval_dataloader(eval_dataset)
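Note on the two dataloader hunks above: `MultipackBatchSampler` is a `BatchSampler`, so it yields whole lists of indices rather than single indices, and it must be handed to `DataLoader` as `batch_sampler` with `batch_size` deleted (the two arguments are mutually exclusive). Setting `even_batches = False` serves the same end: packed batches are intentionally ragged, so accelerate should not try to equalize them. A minimal, self-contained sketch of that dispatch follows; `ToyDataset` and `build_loader` are illustrative stand-ins, not part of this diff:

```python
# Minimal sketch of the sampler-vs-batch_sampler dispatch used above.
# `ToyDataset` and `build_loader` are illustrative, not axolotl code.
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyDataset(Dataset):
    def __init__(self):
        self.rows = [[1, 2], [3, 4], [5, 6], [7, 8]]

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]


def build_loader(dataset, sampler):
    params = {"batch_size": 2, "collate_fn": lambda feats: feats}
    if isinstance(sampler, BatchSampler):
        # A BatchSampler already yields whole batches of indices,
        # so DataLoader must not also apply its own batch_size.
        params["batch_sampler"] = sampler
        del params["batch_size"]
    else:
        params["sampler"] = sampler
    return DataLoader(dataset, **params)


dataset = ToyDataset()
packed = BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=True)
print(list(build_loader(dataset, packed)))  # one batch of 3 rows
print(list(build_loader(dataset, SequentialSampler(dataset))))  # two batches of 2
```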
@@ -229,6 +267,8 @@ def get_bench_dataloader(
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
         }
+        if self.args.dataloader_prefetch_factor:
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
         if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
@@ -493,6 +533,19 @@ def build(self, total_num_steps):
                 "sample_packing_efficiency"
             ] = self.cfg.sample_packing_eff_est
 
+        if self.cfg.dataloader_pin_memory is not None:
+            training_arguments_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_arguments_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_arguments_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         if self.cfg.eval_steps:
             training_arguments_kwargs["evaluation_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
@@ -672,7 +725,7 @@ def build(self, total_num_steps):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=DataCollatorForSeq2Seq(
+            data_collator=BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
                 **data_collator_kwargs,
@@ -690,4 +743,9 @@ def build(self, total_num_steps):
         for callback in self.get_post_trainer_create_callbacks(trainer):
             trainer.add_callback(callback)
 
+        if self.cfg.deepspeed and self.cfg.sample_packing:
+            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "train_micro_batch_size_per_gpu"
+            ] = self.cfg.micro_batch_size
+
         return trainer
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index d7acdc9776..ffae3f2631 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -119,3 +119,30 @@ def __call__(self, features, return_tensors=None):
             features["decoder_input_ids"] = decoder_input_ids
 
         return features
+
+
+@dataclass
+class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack, specific to using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features[0].keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [
+                    (1) * np.array(item[feature])
+                    for item in features
+                    if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [
+                    np.array(item[feature]) for item in features if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 697c26baaf..fd41d2c61b 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -80,11 +80,11 @@ def prepare_dataset(cfg, tokenizer):
         )
     if cfg.max_steps:
         total_num_steps = min(
-            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
         )
         LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
-        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset)
     return train_dataset, eval_dataset, total_num_steps, prompters
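The new `BatchSamplerDataCollatorForSeq2Seq` above receives a whole pre-packed batch at once and flattens it into a single long example before deferring to the stock `DataCollatorForSeq2Seq` for padding and tensorization (as written, the special-cased `attention_mask` branch multiplies by `(1)`, a no-op, so both branches reduce to concatenation). A numpy-only sketch of that merge step; `merge_packed_features` is a hypothetical stand-in for the class's `__call__`:

```python
# Sketch of the concatenation step in BatchSamplerDataCollatorForSeq2Seq,
# using plain numpy so it runs standalone; the real class then hands the
# single merged example to DataCollatorForSeq2Seq for padding/tensorization.
import numpy as np


def merge_packed_features(features):
    """Collapse a batch of pre-packed examples into one long example."""
    chunked = {}
    for key in features[0].keys():
        if key == "length":  # bookkeeping column, not a model input
            continue
        chunked[key] = np.concatenate(
            [np.array(item[key]) for item in features if key in item]
        )
    return [chunked]


batch = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "length": 3},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "length": 2},
]
print(merge_packed_features(batch))
# [{'input_ids': array([1, 2, 3, 4, 5]), 'attention_mask': array([1, 1, 1, 1, 1])}]
```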
diff --git a/src/axolotl/utils/samplers/__init__.py b/src/axolotl/utils/samplers/__init__.py
new file mode 100644
index 0000000000..4c102826f8
--- /dev/null
+++ b/src/axolotl/utils/samplers/__init__.py
@@ -0,0 +1,4 @@
+"""
+axolotl samplers module
+"""
+from .multipack import MultipackBatchSampler  # noqa: F401
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
new file mode 100644
index 0000000000..e576320828
--- /dev/null
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -0,0 +1,193 @@
+# pylint: skip-file
+"""
+Multipack Batch Sampler
+"""
+import logging
+import math
+import os
+from typing import Any, Iterable, List, Union
+
+import numba
+import numpy as np
+from torch.utils.data import BatchSampler, Sampler
+
+LOG = logging.getLogger("axolotl.utils.samplers.multipack")
+
+
+@numba.njit
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
+        not_found = True
+        for idx in range(n):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                not_found = False
+                break
+
+        if not_found:
+            return False
+
+    return True
+
+
+@numba.njit
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [l, r)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length l
+        batch = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        assert len(batch) <= n
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+
+    return result, s, len(result) * c * n
+
+
+class MultipackBatchSampler(BatchSampler):
+    """
+    Batch Sampler class for multipack
+    """
+
+    def __init__(
+        self,
+        sampler: Union[Sampler[int], Iterable[int]],
+        batch_size: int,
+        drop_last: bool,
+        batch_max_len: int,
+        lengths: np.ndarray,
+        packing_efficiency_estimate: float = 1.0,
+    ):
+        super().__init__(sampler, batch_size, drop_last)
+        self.batch_size = None
+        self.batch_max_len = batch_max_len
+        self.lengths: np.ndarray = lengths
+        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+
+        assert isinstance(self.lengths, np.ndarray)
+
+        self.epoch = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def generate_batches(self, set_stats=False):
+        indices = [idx for idx in self.sampler]
+
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=0,
+            c=self.batch_max_len,
+            n=1,
+        )
+
+        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches
+
+    def __iter__(self):
+        batches = self.generate_batches(set_stats=True)
+        return iter(batches)
+
+    def num_batches(self):
+        batches = self.generate_batches(set_stats=True)
+        return len(batches)
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
+
+    def __len__(self):
+        self.num_batches()
+        return self._len_est()
+
+    def _len_est(self):
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
+        lengths_sum = np.sum(self.lengths)
+        lengths_sum_per_device = lengths_sum // world_size
+        LOG.info(
+            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"total_num_tokens per device: {lengths_sum_per_device}"
+        )
+
+        # shave off 1% + 1 to account for variance in packing from one random-sampler epoch to the next
+        return (
+            world_size
+            * math.floor(
+                0.99
+                * lengths_sum_per_device
+                / self.packing_efficiency_estimate
+                // self.batch_max_len
+            )
+            - 1
+        )
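`ffd_with_result` above is the heart of the new sampler: sort sequences by length descending, then drop each one into the first bin that still has room, where a bin is one packed micro-batch with capacity `batch_max_len` tokens. A pure-Python rendering with a worked example, numba stripped; `ffd_demo` is illustrative only, not axolotl API:

```python
# Pure-Python mirror of the first-fit-decreasing step in ffd_with_result.
def ffd_demo(lengths, capacity):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins_free = []   # remaining capacity per bin
    bins_items = []  # original indices packed into each bin
    for idx in order:
        size = lengths[idx]
        for b, free in enumerate(bins_free):
            if free >= size:        # first bin that still fits wins
                bins_free[b] -= size
                bins_items[b].append(idx)
                break
        else:                       # no bin fits: open a new one
            bins_free.append(capacity - size)
            bins_items.append([idx])
    return bins_items


# Sequence lengths 7,5,4,3,2 packed into bins of capacity 8:
print(ffd_demo([7, 5, 4, 3, 2], 8))  # [[0], [1, 3], [2, 4]] -> 7 | 5+3 | 4+2
```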
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 6bcf0dac11..f93316cde8 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -8,20 +8,13 @@
 import numpy as np
 import torch
 import torch.cuda
-import torch.distributed as dist
 from accelerate.logging import get_logger
 from datasets import set_caching_enabled
-from torch.utils.data import DistributedSampler, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder
-from axolotl.utils.collators import DataCollatorForSeq2Seq
-from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import (
-    is_distributed,
-    is_main_process,
-    reduce_and_broadcast,
-    zero_first,
-)
+from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
+from axolotl.utils.samplers import MultipackBatchSampler
 
 LOG = get_logger("axolotl")
@@ -148,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset
 
 
-def calculate_total_num_steps(cfg, train_dataset, tokenizer):
+def calculate_total_num_steps(cfg, train_dataset):
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails
@@ -196,37 +189,36 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                 main_process_only=True,
             )
         else:
-            if cfg.world_size > 1 and is_distributed():
-                sampler = DistributedSampler(
-                    train_dataset,
-                    num_replicas=cfg.world_size,
-                    rank=dist.get_rank(),
-                    seed=cfg.seed or 42,
-                )
-            else:
-                sampler = RandomSampler(train_dataset)
-
-            data_loader = MultipackDistributedDataloader(
-                train_dataset,
+            sampler = MultipackBatchSampler(
+                sampler=RandomSampler(train_dataset),
                 batch_size=cfg.micro_batch_size,
-                seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
-                collate_fn=DataCollatorForSeq2Seq(
-                    tokenizer,
-                    return_tensors="pt",
-                    padding="longest",
+                drop_last=True,
+                batch_max_len=cfg.micro_batch_size
+                * (cfg.max_packed_sequence_len or cfg.sequence_len),
+                lengths=(
+                    train_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
                 ),
-                sampler=sampler,
-                packing_efficiency_estimate=cfg.sample_packing_eff_est,
-                sample_packing_seq_len_multiplier=cfg.micro_batch_size,
-                device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                num_epochs=cfg.num_epochs,
             )
-            data_loader_len = data_loader.len_w_stats()
-            actual_eff = data_loader.efficiency()
+
+            data_loader = DataLoader(
+                train_dataset.remove_columns(["length"]),
+                batch_sampler=sampler,
+            )
+            data_loader_len = len(data_loader)
+            actual_eff = sampler.efficiency()
             LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
+            total_num_steps = int(
+                math.floor(
+                    data_loader_len
+                    * cfg.num_epochs
+                    / int(os.environ.get("WORLD_SIZE", 1))
+                )
+            )
 
             def calc_sample_packing_eff_est(estimates: List[float]):
                 LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -246,7 +238,12 @@ def calc_sample_packing_eff_est(estimates: List[float]):
             )
     else:
         total_num_steps = int(
-            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+            math.ceil(
+                len(train_dataset)
+                * cfg.num_epochs
+                / int(os.environ.get("WORLD_SIZE", 1))
+                / cfg.batch_size
+            )
        )
     LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
     return total_num_steps
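For completeness, a sketch of driving the new sampler outside the trainer. It assumes axolotl (and its numba dependency) is importable; the lengths are made up. Note that the sampler ignores `batch_size` (the constructor resets it to `None` and packs purely by `batch_max_len`), and `len(sampler)` is only an estimate derived from `packing_efficiency_estimate`, while iterating yields the actual packed batches:

```python
# Standalone sketch of MultipackBatchSampler with toy data; the lengths
# array here is invented for illustration.
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from axolotl.utils.samplers import MultipackBatchSampler

lengths = np.array([512, 1024, 256, 2048, 768, 1536, 128, 896])
dataset = TensorDataset(torch.arange(len(lengths)))

sampler = MultipackBatchSampler(
    RandomSampler(dataset),
    batch_size=1,                # ignored: the class packs by batch_max_len
    drop_last=True,
    batch_max_len=2048,          # one micro-batch holds up to 2048 tokens
    lengths=lengths,
    packing_efficiency_estimate=1.0,
)

loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda x: x)
for batch in loader:
    print(batch)                 # rows fetched for one packed bin of indices
print("efficiency:", sampler.efficiency())
```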