From 05bd6f11225009cc35638922fca301db0d2e586a Mon Sep 17 00:00:00 2001
From: Casper
Date: Thu, 26 Oct 2023 07:49:52 +0200
Subject: [PATCH] Threaded MultipackDistributedDataloader with prefetched samples (#759)

* Multithreading implementation [WIP]
* Added benchmarking
* 35% increased throughput
* Memory pinning
* Start threads in init
* Correct print of samples
* Sleep if queue is full
* Remove pin_memory (worse)
* Simplify logic to one thread
* Remove benchmark
* Use deque for constant speed
* Formatting
* Formatting
* Formatting
* Formatting
* Rollback to use queue
* Fix multi-epoch training
* Add num epochs arg
* Start thread in __iter__
* Formatting
* Use is_alive correctly
* Simplify loading thread
---
 src/axolotl/core/trainer_builder.py |  6 +++-
 src/axolotl/utils/dataloader.py     | 50 ++++++++++++++++++++++++++---
 src/axolotl/utils/trainer.py        |  1 +
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 00a1a0c670..55a1764fcd 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
-    def __init__(self, *args, bench_data_collator=None, **kwargs):
+    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         super().__init__(*args, **kwargs)
 
@@ -182,6 +183,7 @@ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataload
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=self.num_epochs,
             )
         )
         return super().get_train_dataloader()
@@ -205,6 +207,7 @@ def get_eval_dataloader(
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                num_epochs=self.num_epochs,
             )
         )
         return super().get_eval_dataloader(eval_dataset)
@@ -680,6 +683,7 @@ def build(self, total_num_steps):
                 **data_collator_kwargs,
             ),
             callbacks=self.get_callbacks(),
+            num_epochs=self.cfg.num_epochs,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index d659c3d334..54c95db787 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -3,6 +3,9 @@
 import itertools
 import logging
 import math
+import time
+from queue import Queue
+from threading import Thread
 from typing import Any, Callable, List, Union
 
 import numba
@@ -149,6 +152,8 @@ def __init__(
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
+        prefetch_max: int = 1000,
+        num_epochs: int = 1,
     ):
         # Dataset
         self.dataset = dataset
@@ -167,6 +172,7 @@ def __init__(
         self.seq_max_length = seq_max_length
         self.batch_max_length = batch_size * seq_max_length
         self.collate_fn = collate_fn
+        self.num_epochs = num_epochs
 
         self.num_replicas = 1
         self.rank = 0
@@ -177,6 +183,44 @@ def __init__(
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
+        # maxsize is maximum number of samples in queue
+        self.prefetch_max = prefetch_max
+        self.queue: Queue = Queue(maxsize=prefetch_max)
+        self.thread = None
+
+    def _worker(self):
+        LOG.info(
+            f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
+        )
+        for epoch in range(self.num_epochs):
+            for sample in self._internal_batch_generator():
+                while True:
+                    if self.queue.full():
+                        time.sleep(1)
+                    else:
+                        break
+                self.queue.put(sample)
+
+        # stop the queue when epoch is done
+        self.queue.put(None)
+
+    def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+
+        if self.thread is None:
+            self.thread = Thread(target=self._worker, daemon=True)
+            self.thread.start()
+
+        while True:
+            item = self.queue.get()
+
+            if item is None:
+                break
+            yield item
+
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
         if self.sampler:
@@ -206,11 +250,7 @@ def generate_batches(self, set_stats=False):
 
         return batches, totseqs
 
-    def __iter__(self):
-        if hasattr(self.sampler, "set_epoch"):
-            new_epoch = self.sampler.epoch + 1
-            self.sampler.set_epoch(new_epoch)
-            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+    def _internal_batch_generator(self):
         all_batches, _ = self.generate_batches(set_stats=True)
         features = self.dataset.features.keys()
         len_remaining = self._len_est()
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 0d275cbf55..d04390293e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
             sample_packing_seq_len_multiplier=cfg.micro_batch_size,
             device_count=int(os.environ.get("WORLD_SIZE", 1)),
+            num_epochs=cfg.num_epochs,
         )
         data_loader_len = data_loader.len_w_stats()
         actual_eff = data_loader.efficiency()
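
For readers of this patch, the change boils down to a producer/consumer pattern: _worker() packs batches on a background thread and pushes them into a bounded Queue, __iter__() pops them, and a None sentinel signals the end of iteration. Below is a minimal, self-contained sketch of that pattern; PrefetchingLoader, slow_batch_generator, and their parameters are illustrative stand-ins, not axolotl APIs.

    # Sketch only: a background thread prefetches batches into a bounded queue,
    # and a None sentinel tells the consumer to stop (mirrors the patch's logic).
    import time
    from queue import Queue
    from threading import Thread


    def slow_batch_generator(num_epochs=2, batches_per_epoch=5):
        # Stand-in for _internal_batch_generator(); the real one packs samples.
        for epoch in range(num_epochs):
            for step in range(batches_per_epoch):
                time.sleep(0.01)  # simulate packing/collation work
                yield {"epoch": epoch, "step": step}


    class PrefetchingLoader:
        def __init__(self, prefetch_max=100):
            self.queue = Queue(maxsize=prefetch_max)  # caps items held in memory
            self.thread = None

        def _worker(self):
            for batch in slow_batch_generator():
                self.queue.put(batch)  # blocks while the queue is full
            self.queue.put(None)  # sentinel: no more batches

        def __iter__(self):
            if self.thread is None:
                self.thread = Thread(target=self._worker, daemon=True)
                self.thread.start()
            while True:
                item = self.queue.get()
                if item is None:
                    break
                yield item


    if __name__ == "__main__":
        for batch in PrefetchingLoader():
            print(batch)

The bounded queue (prefetch_max in the patch, default 1000) is what keeps memory in check: the worker stays at most that many items ahead of the training loop, so batch packing overlaps with GPU work without buffering entire epochs.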