Extend Trainer to enable Ascend NPU to use the fused AdamW optimizer when training (#26194)
statelesshz authored Oct 4, 2023
1 parent fc296f4 commit 4fdf47c
Showing 2 changed files with 9 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/transformers/trainer.py
@@ -1068,6 +1068,14 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
         elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
             try:
                 from apex.optimizers import FusedAdam
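For context, a minimal sketch (not part of this commit) of how the new branch resolves the optimizer. It assumes a transformers build that includes this change and, at runtime, an Ascend NPU environment with the torch_npu package installed; without torch_npu the new except branch raises the ValueError shown above. The output_dir value is illustrative.

from transformers import Trainer, TrainingArguments

# "adamw_torch_npu_fused" is the string value of the new enum member added below.
args = TrainingArguments(output_dir="tmp_out", optim="adamw_torch_npu_fused")

# get_optimizer_cls_and_kwargs is a static method, so no Trainer instance is needed.
# With torch_npu installed it returns torch_npu.optim.NpuFusedAdamW together with
# the usual AdamW keyword arguments (lr, betas, eps); otherwise it raises
# ValueError("Trainer failed to import FusedAdamW from torch_npu.").
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
print(optimizer_cls.__name__)  # NpuFusedAdamW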
1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -140,6 +140,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH = "adamw_torch"
     ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
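A small check (again not part of the commit) of how the new enum member is reached from user code; the same string can typically be passed as --optim adamw_torch_npu_fused when a training script builds TrainingArguments via HfArgumentParser.

from transformers.training_args import OptimizerNames

# Lookup by value: the string passed to TrainingArguments(optim=...) or to
# --optim on the command line maps to the new enum member.
assert OptimizerNames("adamw_torch_npu_fused") is OptimizerNames.ADAMW_TORCH_NPU_FUSED
print(list(OptimizerNames))  # the new member appears alongside the other AdamW variants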
