diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index b9e10376157457..72b2cc049a174b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1068,6 +1068,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
         elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
             try:
                 from apex.optimizers import FusedAdam
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 07e3d04ef91075..7c016b15b2e648 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -140,6 +140,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH = "adamw_torch"
     ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
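The snippet below is a minimal sketch of how the new option would be selected, not part of the patch itself. It assumes a host with an Ascend NPU and the torch_npu package installed; the output directory is a placeholder. It exercises the code path changed above: `TrainingArguments(optim="adamw_torch_npu_fused")` maps to `OptimizerNames.ADAMW_TORCH_NPU_FUSED`, and `Trainer.get_optimizer_cls_and_kwargs` resolves that to `NpuFusedAdamW` with the shared `adam_kwargs`.

```python
from transformers import Trainer, TrainingArguments

# Select the NPU fused optimizer by its string name (the enum value added above).
args = TrainingArguments(
    output_dir="out",               # placeholder output directory
    optim="adamw_torch_npu_fused",  # -> OptimizerNames.ADAMW_TORCH_NPU_FUSED
    learning_rate=2e-5,
)

# The static helper patched above returns the optimizer class and its kwargs.
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
print(optimizer_cls.__name__)  # "NpuFusedAdamW" when torch_npu is importable
print(optimizer_kwargs)        # lr plus betas/eps from adam_kwargs
```

If torch_npu is not importable, the new branch raises `ValueError("Trainer failed to import FusedAdamW from torch_npu.")`, mirroring the existing XLA and Apex fallbacks.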