From 16482796b08da940e249f4238e935d27d8c5f267 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Tue, 27 Feb 2024 00:45:14 +0100
Subject: [PATCH] add lion-pytorch optimizer (#1299) [skip ci]

* add lion-pytorch optimizer

* update pydantic to support lion optimizer

---------

Co-authored-by: Wing Lian
---
 setup.py                                    |  4 +++
 src/axolotl/core/trainer_builder.py         | 36 +++++++++++++++----
 .../config/models/input/v0_4_1/__init__.py  |  2 +-
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/setup.py b/setup.py
index aa730fbe69..85d9eae36d 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@ def parse_requirements():
                 or "flash-attention" in line
                 or "deepspeed" in line
                 or "mamba-ssm" in line
+                or "lion-pytorch" in line
             )
             if line.startswith("--extra-index-url"):
                 # Handle custom index URLs
@@ -85,5 +86,8 @@ def parse_requirements():
         "mlflow": [
             "mlflow",
         ],
+        "lion-pytorch": [
+            "lion-pytorch==0.1.2",
+        ],
     },
 )
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 5dca1e2b6b..3502b229cb 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -970,19 +970,43 @@ def build(self, total_num_steps):
                 "neftune_noise_alpha"
             ] = self.cfg.neftune_noise_alpha
 
-        training_args = (
-            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-                **training_arguments_kwargs,
-            )
-        )
-        training_args = self.hook_post_create_training_args(training_args)
         trainer_kwargs = {}
 
+        if self.cfg.optimizer == "lion_pytorch":
+            from lion_pytorch import Lion
+
+            lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
+            if "weight_decay" in training_arguments_kwargs:
+                lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
+
+            if (
+                "adam_beta1" in training_arguments_kwargs
+                and "adam_beta2" in training_arguments_kwargs
+            ):
+                lion_kwargs["betas"] = (
+                    training_arguments_kwargs["adam_beta1"],
+                    training_arguments_kwargs["adam_beta2"],
+                )
+
+            trainer_kwargs["optimizers"] = (
+                Lion(params=self.model.parameters(), **lion_kwargs),
+                None,
+            )
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+
         if self.cfg.optimizer == "adamw_anyprecision":
             if Path(self.cfg.torchdistx_path).exists():
                 sys.path.append(self.cfg.torchdistx_path)
                 importlib.import_module("torchdistx")
 
+        training_args = (
+            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
+                **training_arguments_kwargs,
+            )
+        )
+        training_args = self.hook_post_create_training_args(training_args)
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index f0ab7f235b..e810fc4df7 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -263,7 +263,7 @@ class HyperparametersConfig(BaseModel):
 
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = None
-    optimizer: Optional[OptimizerNames] = None
+    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[SchedulerType] = None
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
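
For context, a standalone sketch (not part of the patch) of the pattern the trainer_builder.py change relies on: construct a lion_pytorch.Lion optimizer by hand and hand it to the Hugging Face Trainer through the optimizers tuple, while optim still names a stock optimizer so TrainingArguments validates. The build_trainer_with_lion helper, the model/train_dataset arguments, the default hyperparameters, and the output_dir value are illustrative placeholders, not axolotl code.

# Illustrative sketch: mirrors what the build() hunk above does, outside of axolotl.
from lion_pytorch import Lion
from transformers import Trainer, TrainingArguments


def build_trainer_with_lion(model, train_dataset, learning_rate=1e-4,
                            weight_decay=0.0, betas=(0.9, 0.99)):
    # Hand-construct the optimizer; Lion is usually run with a smaller lr
    # and larger weight decay than AdamW.
    optimizer = Lion(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=betas,
    )

    # `optim` must still name a built-in optimizer so TrainingArguments
    # validates, but the Trainer skips creating it because an optimizer
    # is passed explicitly below.
    args = TrainingArguments(output_dir="./out", optim="adamw_hf")

    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        # (optimizer, lr_scheduler); the scheduler is left for the Trainer to create
        optimizers=(optimizer, None),
    )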