ADOPT optimizer integration (#2032) [skip ci]
* adopt integration

* stuff

* doc and test for ADOPT

* rearrangement

* fixed formatting

* hacking pre-commit

* chore: lint

* update module doc for adopt optimizer

* remove unnecessary example yaml for adopt optimizer

* skip test adopt if torch<2.5.1

* formatting

* use version.parse

* specifies required torch version for adopt_adamw

---------

Co-authored-by: sunny <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
3 people authored Nov 13, 2024
1 parent 659ee5d commit 1d7aee0
Showing 6 changed files with 588 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/config.qmd
@@ -405,6 +405,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd
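A note on the new option: per the hunk above and the commit messages ("skip test adopt if torch<2.5.1", "use version.parse"), adopt_adamw is only usable on torch >= 2.5.1. Below is a minimal sketch, not part of this commit, of how such a version gate can be expressed with packaging.version; the helper name adopt_supported and the fallback choice are illustrative assumptions.

import torch
from packaging import version

def adopt_supported() -> bool:
    # True when the installed torch is new enough for the adopt_adamw option.
    return version.parse(torch.__version__) >= version.parse("2.5.1")

# Pick adopt_adamw only when supported; otherwise fall back to a stock optimizer.
optimizer_name = "adopt_adamw" if adopt_supported() else "adamw_torch"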
18 changes: 17 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -436,7 +436,13 @@ def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()

@@ -505,6 +511,14 @@ def create_optimizer(self):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT

self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
)
)

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1625,11 +1639,13 @@ def build(self, total_num_steps):
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len

# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
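Taken together, the trainer changes route optimizer: adopt_adamw around the stock create_optimizer path, instantiate the new ADOPT class with decoupled weight decay, and set a placeholder optim value so transformers does not reject the config. Below is a self-contained sketch of that instantiation; only the import path and the decoupled=True argument come from the diff above, while the model, learning rate, and training step are placeholder assumptions.

import torch
import torch.nn as nn
from axolotl.utils.optimizers.adopt import ADOPT  # module added by this commit

model = nn.Linear(16, 4)  # stand-in model for illustration
optimizer = ADOPT(model.parameters(), lr=1e-4, decoupled=True)  # kwargs assumed

# One illustrative training step with the new optimizer.
loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()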
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -428,6 +428,7 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF.value
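The hunk above registers "adopt_adamw" in the pydantic Literal that validates the optimizer field, so the new name passes config validation. For context, here is a minimal sketch of how such a Literal field accepts the new value and rejects unknown ones; the TinyHyperparams class and its shortened option list are illustrative, not axolotl's actual model.

from typing import Literal, Optional
from pydantic import BaseModel, ValidationError

class TinyHyperparams(BaseModel):  # illustrative stand-in, not the project's class
    optimizer: Optional[Literal["adamw_hf", "adopt_adamw"]] = "adamw_hf"

TinyHyperparams(optimizer="adopt_adamw")   # accepted once the Literal includes it
try:
    TinyHyperparams(optimizer="adopt_sgd")  # unknown name -> ValidationError
except ValidationError as err:
    print(err)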
(Diffs for the remaining 3 changed files were not loaded on this page.)
