From 00b93cda213aa3e754ebbec6ccda77cca4b67247 Mon Sep 17 00:00:00 2001 From: Alan Ji Date: Wed, 9 Aug 2023 19:50:13 +0800 Subject: [PATCH] Improve training args (#25401) * enhanced tips for some training args * make style --- src/transformers/training_args.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 1f6e81959ec16f..95732bbfc551c7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -193,9 +193,9 @@ class TrainingArguments: prediction_loss_only (`bool`, *optional*, defaults to `False`): When performing evaluation and generating predictions, only returns the loss. per_device_train_batch_size (`int`, *optional*, defaults to 8): - The batch size per GPU/TPU core/CPU for training. + The batch size per GPU/TPU/MPS/NPU core/CPU for training. per_device_eval_batch_size (`int`, *optional*, defaults to 8): - The batch size per GPU/TPU core/CPU for evaluation. + The batch size per GPU/TPU/MPS/NPU core/CPU for evaluation. gradient_accumulation_steps (`int`, *optional*, defaults to 1): Number of updates steps to accumulate the gradients for, before performing a backward/update pass. @@ -648,10 +648,10 @@ class TrainingArguments: ) per_device_train_batch_size: int = field( - default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} + default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."} ) per_device_eval_batch_size: int = field( - default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} + default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."} ) per_gpu_train_batch_size: Optional[int] = field( @@ -804,7 +804,9 @@ class TrainingArguments: ) use_cpu: bool = field( default=False, - metadata={"help": " Whether or not to use cpu. If set to False, we will use cuda or mps device if available."}, + metadata={ "help": " Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available." }, ) use_mps_device: bool = field( default=False,