add bf16 mixed precision support for NPU (#26163)
Co-authored-by: statelesshz <[email protected]>
statelesshz authored Sep 27, 2023
1 parent 153755e commit 946bac7
Showing 1 changed file with 16 additions and 6 deletions.
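The changes below let the `bf16` flag be used on Huawei Ascend NPUs. A minimal usage sketch, assuming `torch>=1.11` and `torch_npu` are installed (the output directory and batch size are illustrative values, not part of this commit):

# Sketch: requesting bf16 mixed precision on an Ascend NPU after this change.
# Assumes torch>=1.11 and torch_npu are installed; output_dir and the batch
# size are placeholder values.
from transformers import TrainingArguments
from transformers.utils import is_torch_npu_available  # availability helper used by training_args.py

print("NPU available:", is_torch_npu_available())

args = TrainingArguments(
    output_dir="./outputs",
    bf16=True,  # now accepted on Ascend NPU as well as Ampere+ GPUs and CPU
    per_device_train_batch_size=8,
)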
src/transformers/training_args.py (22 changes: 16 additions & 6 deletions)
@@ -211,7 +211,7 @@ class TrainingArguments:
         eval_accumulation_steps (`int`, *optional*):
             Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
-            left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster but
+            left unset, the whole predictions are accumulated on GPU/NPU/TPU before being moved to the CPU (faster but
             requires more memory).
         eval_delay (`float`, *optional*):
             Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
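As the docstring above notes, `eval_accumulation_steps` bounds how many prediction steps stay on the GPU/NPU/TPU before the outputs are moved to the CPU; leaving it unset keeps all predictions on the device. A small illustrative configuration (the value 32 is arbitrary):

# Sketch: trading evaluation speed for device memory by offloading accumulated
# prediction tensors to the CPU every 32 steps instead of keeping them all on device.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",        # placeholder
    eval_accumulation_steps=32,    # move accumulated outputs to CPU every 32 steps
)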
@@ -318,7 +318,7 @@ class TrainingArguments:
             installation](https://github.com/intel/intel-extension-for-pytorch).
         bf16 (`bool`, *optional*, defaults to `False`):
             Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
-            NVIDIA architecture or using CPU (use_cpu). This is an experimental API and it may change.
+            NVIDIA architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change.
         fp16 (`bool`, *optional*, defaults to `False`):
             Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
         fp16_opt_level (`str`, *optional*, defaults to 'O1'):
@@ -344,7 +344,7 @@ class TrainingArguments:
         local_rank (`int`, *optional*, defaults to -1):
             Rank of the process during distributed training.
         ddp_backend (`str`, *optional*):
-            The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`.
+            The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`.
         tpu_num_cores (`int`, *optional*):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
         dataloader_drop_last (`bool`, *optional*, defaults to `False`):
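`hccl` (Huawei Collective Communication Library) becomes a valid `ddp_backend` choice, analogous to `nccl` for NVIDIA GPUs. A hedged sketch of selecting it for a multi-NPU run (the launch command and all argument values are illustrative):

# Sketch: distributed training across Ascend NPUs using the HCCL backend.
# Typically launched with a distributed launcher, e.g.
#   torchrun --nproc_per_node=8 train.py
# (script name and device count are placeholders).
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    ddp_backend="hccl",  # new choice added by this commit
    bf16=True,
)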
@@ -855,7 +855,7 @@ class TrainingArguments:
         metadata={
             "help": (
                 "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
-                " architecture or using CPU (use_cpu). This is an experimental API and it may change."
+                " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
             )
         },
     )
@@ -906,7 +906,7 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "The backend to be used for distributed training",
-            "choices": ["nccl", "gloo", "mpi", "ccl"],
+            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl"],
         },
     )
     tpu_num_cores: Optional[int] = field(
@@ -1376,6 +1376,15 @@ def __post_init__(self):
                     raise ValueError(
                         "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
                     )
+                elif is_torch_npu_available():
+                    # npu
+                    from .pytorch_utils import is_torch_greater_or_equal_than_1_11
+
+                    if not is_torch_greater_or_equal_than_1_11:
+                        raise ValueError(
+                            "Your setup doesn't support bf16/npu. You need torch>=1.11, using Ascend NPU with "
+                            "`torch_npu` installed"
+                        )
                 elif not is_torch_xpu_available():
                     # xpu
                     from .pytorch_utils import is_torch_greater_or_equal_than_1_12
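The added branch mirrors the existing GPU and XPU guards: when an NPU is detected, bf16 is only allowed on `torch>=1.11`. A standalone sketch of an equivalent check, assuming `packaging` is available (the internal `is_torch_greater_or_equal_than_1_11` helper is a version comparison of this kind, and the function name below is local to this example):

# Sketch: a version guard equivalent to the one added above, outside TrainingArguments.
import torch
from packaging import version


def check_bf16_npu_support() -> None:
    # Reject bf16 on NPU when the installed torch is older than 1.11.
    if version.parse(torch.__version__).release < (1, 11):
        raise ValueError(
            "Your setup doesn't support bf16/npu. You need torch>=1.11, "
            "using Ascend NPU with `torch_npu` installed"
        )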
@@ -1439,6 +1448,7 @@ def __post_init__(self):
             self.framework == "pt"
             and is_torch_available()
             and (self.device.type != "cuda")
+            and (self.device.type != "npu")
             and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) != "GPU")
             and (get_xla_device_type(self.device) != "TPU")
@@ -1447,7 +1457,7 @@ def __post_init__(self):
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX) or CPU/TPU/NeuronCore devices."
+                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU or CPU/TPU/NeuronCore devices."
             )

         if self.torchdynamo is not None:
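With `self.device.type != "npu"` added to the condition, this post-init sanity check no longer rejects bf16 AMP training or bf16 full evaluation on NPU devices. A short sketch combining both flags (all other values are placeholders):

# Sketch: bf16 mixed precision training plus full-bf16 evaluation, which the
# updated check now permits when the selected device is an Ascend NPU.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    bf16=True,            # bf16 AMP for training
    bf16_full_eval=True,  # run evaluation entirely in bf16
)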
