Skip to content

Commit

Permalink
Improve training args (#25401)
Browse files Browse the repository at this point in the history
* enhanced tips for some training args

* make style
  • Loading branch information
statelesshz authored Aug 9, 2023
1 parent 3deed1f commit 00b93cd
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ class TrainingArguments:
prediction_loss_only (`bool`, *optional*, defaults to `False`):
When performing evaluation and generating predictions, only returns the loss.
per_device_train_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/TPU core/CPU for training.
The batch size per GPU/TPU/MPS/NPU core/CPU for training.
per_device_eval_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/TPU core/CPU for evaluation.
The batch size per GPU/TPU/MPS/NPU core/CPU for evaluation.
gradient_accumulation_steps (`int`, *optional*, defaults to 1):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
Expand Down Expand Up @@ -648,10 +648,10 @@ class TrainingArguments:
)

per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."}
)

per_gpu_train_batch_size: Optional[int] = field(
Expand Down Expand Up @@ -804,7 +804,9 @@ class TrainingArguments:
)
use_cpu: bool = field(
default=False,
metadata={"help": " Whether or not to use cpu. If set to False, we will use cuda or mps device if available."},
metadata={
"help": " Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available."
},
)
use_mps_device: bool = field(
default=False,
Expand Down

0 comments on commit 00b93cd

Please sign in to comment.