Enable split_batches through TrainingArguments (#26798)
* Enable split_batches through TrainingArguments

* Extra dispatch_batches

* Keep as default false

* Add to docstring

* Add to docstring

* Remove the capturewarnings change

* Comma
muellerzr authored Nov 1, 2023
1 parent 95020f2 commit 3520e37
Showing 2 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/trainer.py
@@ -3906,6 +3906,7 @@ def create_accelerator_and_postprocess(self):
        # create accelerator object
        self.accelerator = Accelerator(
            dispatch_batches=self.args.dispatch_batches,
+           split_batches=self.args.split_batches,
            deepspeed_plugin=self.args.deepspeed_plugin,
            gradient_accumulation_plugin=gradient_accumulation_plugin,
        )
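For reference, here is what the forwarded flag does at the Accelerate level. This is a standalone sketch, not Trainer code; the dataset and batch size are illustrative:

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

# With split_batches=True, each batch yielded by the dataloader is split
# across the distributed processes: batch_size is the global batch size
# and must be a round multiple of the number of processes.
accelerator = Accelerator(split_batches=True)

dataset = TensorDataset(torch.arange(64.0).unsqueeze(1))
loader = accelerator.prepare(DataLoader(dataset, batch_size=16))

for (batch,) in loader:
    # Each process sees 16 // accelerator.num_processes samples per step;
    # the effective batch size stays 16 regardless of the world size.
    pass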
17 changes: 17 additions & 0 deletions src/transformers/training_args.py
@@ -621,6 +621,14 @@ class TrainingArguments:
        Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
        This flag is experimental and subject to change in future releases.
+       split_batches (`bool`, *optional*, defaults to `False`):
+           Whether or not the accelerator should split the batches yielded by the dataloaders across the devices
+           during distributed training. If set to `True`, the actual batch size used will be the same regardless
+           of the number of processes, but it must be a round multiple of the number of processes you are using
+           (such as GPUs).
        include_tokens_per_second (`bool`, *optional*):
            Whether or not to compute the number of tokens per second per device for training speed metrics.
@@ -1226,6 +1234,15 @@ class TrainingArguments:
        },
    )

+   split_batches: Optional[bool] = field(
+       default=False,
+       metadata={
+           "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the"
+           " devices during distributed training. If set to `True`, the actual batch size used will be the same"
+           " regardless of the number of processes, but it must be a round multiple of the number of processes"
+           " you are using (such as GPUs)."
+       },
+   )

    include_tokens_per_second: Optional[bool] = field(
        default=False,
        metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
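With the new field in place, the flag can be set directly on TrainingArguments and reaches the Accelerator through create_accelerator_and_postprocess shown above. A minimal usage sketch (output_dir and the batch size are illustrative):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    split_batches=True,  # forwarded to accelerate.Accelerator; defaults to False
)
print(args.split_batches)  # True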
