Skip to content

Commit

Permalink
optionally configure sample packing for evals (#589)
Browse files — browse the repository at this point in the history
  • Loading branch information
winglian authored Sep 16, 2023
1 parent 62eaee7 commit 21ec195
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
Expand Down Expand Up @@ -212,7 +216,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
- if self.args.world_size > 1 and self.args.sample_packing:
+ if (
+ self.args.world_size > 1
+ and self.args.sample_packing
+ and self.args.eval_sample_packing is not False
+ ):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
Expand Down Expand Up @@ -241,7 +249,7 @@ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataload
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
- if self.args.sample_packing:
+ if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
Expand Down Expand Up @@ -659,6 +667,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
eval_sample_packing=cfg.eval_sample_packing,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps,
Expand Down

0 comments on commit 21ec195

Please sign in to comment.