From 3520e37e86913715959ff14fef76340010c8de57 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 1 Nov 2023 14:42:38 -0400 Subject: [PATCH] Enable split_batches through TrainingArguments (#26798) * Enable split_batches through TrainingArguments * Extra dispatch_batches * Keep as default false * Add to docstring * Add to docstring * Remove the capturewarnings change * Comma --- src/transformers/trainer.py | 1 + src/transformers/training_args.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index aa5e372bdc2..5def3ca8904 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3906,6 +3906,7 @@ def create_accelerator_and_postprocess(self): # create accelerator object self.accelerator = Accelerator( dispatch_batches=self.args.dispatch_batches, + split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, ) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 147d1e6b1c6..aaedc83528a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -621,6 +621,14 @@ class TrainingArguments: Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions. This flag is experimental and subject to change in future releases. + split_batches (`bool`, *optional*): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices + during distributed training. If + + set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it + must be a + + round multiple of the number of processes you are using (such as GPUs). include_tokens_per_second (`bool`, *optional*): Whether or not to compute the number of tokens per second per device for training speed metrics. 
@@ -1226,6 +1234,15 @@ class TrainingArguments: }, ) + split_batches: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices during distributed training. If " + "set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a " + "round multiple of the number of processes you are using (such as GPUs)." + }, + ) + include_tokens_per_second: Optional[bool] = field( + default=False, + metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},