From 21ec195c9f5561636ecbdba709abcce1e82cc742 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 16 Sep 2023 00:09:48 -0400
Subject: [PATCH] optionally configure sample packing for evals (#589)

---
 src/axolotl/utils/trainer.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 555737f761..ee3e9d2f27 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
     )
+    eval_sample_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Use sample packing for efficient evals."},
+    )
     sample_packing_efficiency: float = field(
         default=1.0,
         metadata={"help": "Sample packing efficiency for calculating batch length."},
@@ -212,7 +216,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
     def _get_eval_sampler(
         self, eval_dataset: Dataset
     ) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
+        if (
+            self.args.world_size > 1
+            and self.args.sample_packing
+            and self.args.eval_sample_packing is not False
+        ):
             return SequentialDistributedSampler(
                 eval_dataset,
                 num_replicas=self.args.world_size,
@@ -241,7 +249,7 @@ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataload
     def get_eval_dataloader(
         self, eval_dataset: Optional[Dataset] = None
     ) -> Union[DataLoader, MultipackDistributedDataloader]:
-        if self.args.sample_packing:
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
@@ -659,6 +667,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
+        eval_sample_packing=cfg.eval_sample_packing,
         sample_packing_seq_len_multiplier=cfg.micro_batch_size,
         relora_steps=cfg.relora_steps,
         relora_warmup_steps=cfg.relora_warmup_steps,