Skip to content

Commit

Permalink
add basic support for the optimi adamw optimizer (#1727)
Browse files Browse the repository at this point in the history
* add support for optimi_adamw optimizer w kahan summation

* pydantic validator for optimi_adamw

* workaround for setting optimizer for fsdp

* make sure to install optimizer packages

* make sure to have parity for model parameters passed to optimizer

* add smoke test for optimi_adamw optimizer

* don't use foreach optimi by default
  • Loading branch information
winglian authored Jul 14, 2024
1 parent 98af538 commit 78e12f8
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 18 deletions.
4 changes: 2 additions & 2 deletions cicd/Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi

# So we can test the Docker image
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi

# So we can test the Docker image
Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,11 @@ def parse_requirements():
"galore": [
"galore_torch",
],
"optimizers": [
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
},
)
65 changes: 55 additions & 10 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)


@dataclass
Expand Down Expand Up @@ -285,25 +291,59 @@ def __init__(
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer != "optimi_adamw"
):
return super().create_optimizer()

opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)

loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW

self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -1396,6 +1436,11 @@ def build(self, total_num_steps):

trainer_kwargs = {}

if self.cfg.optimizer == "optimi_adamw":
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer

if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion

Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch"]]
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e/test_lora_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def test_lora(self, temp_dir):
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
Expand All @@ -50,7 +50,7 @@ def test_lora(self, temp_dir):
"type": "alpaca",
},
],
"num_epochs": 2,
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
Expand Down
67 changes: 67 additions & 0 deletions tests/e2e/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
E2E tests for custom optimizers using Llama
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestCustomOptimizers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""

@with_temp_dir
def test_optimi_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "optimi_adamw",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

0 comments on commit 78e12f8

Please sign in to comment.