From a5b93ae106dec21f8a34f9c4fdaa142fd4cb238c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 5 Apr 2024 17:40:55 -0700
Subject: [PATCH] add support for adamw schedulefree

---
 requirements.txt                            |  2 +
 src/axolotl/core/trainer_builder.py         | 58 ++++++++++++++++++-
 .../config/models/input/v0_4_1/__init__.py  |  2 +-
 3 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 785ede535e..a872c18c32 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -41,3 +41,5 @@ gcsfs
 
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
 zstandard==0.22.0
+
+schedulefree==1.2.1
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index cc7275184e..5f095e5a67 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -14,11 +14,13 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Type, Union
 
+import schedulefree
 import torch
 import transformers
 from datasets import Dataset
+from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -27,7 +29,7 @@
     TrainerCallback,
     TrainingArguments,
 )
-from transformers.trainer_utils import seed_worker
+from transformers.trainer_utils import EvalLoopOutput, seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
 from trl.trainer.utils import pad_to_length
@@ -486,6 +488,31 @@ def compute_loss(self, model, inputs, return_outputs=False):
             return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
+    ) -> torch.Tensor:
+        if self.optimizer.__class__.__name__ == "AdamWScheduleFree":
+            self.optimizer.train()
+        return super().training_step(model, inputs)
+
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        if self.optimizer.__class__.__name__ == "AdamWScheduleFree":
+            self.optimizer.eval()
+        return super().evaluation_loop(
+            dataloader,
+            description,
+            prediction_loss_only=prediction_loss_only,
+            ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
+        )
+
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}
@@ -1297,6 +1324,33 @@ def build(self, total_num_steps):
             sys.path.append(self.cfg.torchdistx_path)
             importlib.import_module("torchdistx")
 
+        if self.cfg.optimizer == "schedule_free_adamw":
+            sfa_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
+            if "adam_epsilon" in training_arguments_kwargs:
+                sfa_kwargs["eps"] = training_arguments_kwargs["adam_epsilon"]
+
+            if "weight_decay" in training_arguments_kwargs:
+                sfa_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
+
+            sfa_kwargs["warmup_steps"] = training_arguments_kwargs["warmup_steps"]
+
+            if (
+                "adam_beta1" in training_arguments_kwargs
+                and "adam_beta2" in training_arguments_kwargs
+            ):
+                sfa_kwargs["betas"] = (
+                    training_arguments_kwargs["adam_beta1"],
+                    training_arguments_kwargs["adam_beta2"],
+                )
+
+            trainer_kwargs["optimizers"] = (
+                schedulefree.AdamWScheduleFree(
+                    params=self.model.parameters(), **sfa_kwargs
+                ),
+                None,
+            )
+            training_arguments_kwargs["optim"] = "adamw_hf"
+
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index ad332da2da..886de0fe9b 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -314,7 +314,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch"]]
+        Union[OptimizerNames, Literal["lion_pytorch", "schedule_free_adamw"]]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
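
For reference, below is a minimal standalone sketch (not part of the patch, and not axolotl code) of the schedulefree behavior the training_step/evaluation_loop overrides above depend on: AdamWScheduleFree replaces the LR scheduler (hence the None passed alongside it in trainer_kwargs["optimizers"]) but must be put into train mode before optimizer steps and into eval mode before evaluation or checkpointing. The model, data, and hyperparameters here are illustrative placeholders.

    # Sketch only: mirrors the optimizer.train()/optimizer.eval() toggling done in the trainer.
    import schedulefree
    import torch
    from torch import nn

    model = nn.Linear(16, 1)  # placeholder model
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)

    optimizer.train()  # switch optimizer to train mode before taking steps
    for _ in range(10):
        loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    optimizer.eval()  # switch to eval mode before evaluation or saving a checkpoint
    with torch.no_grad():
        eval_loss = model(torch.randn(8, 16)).pow(2).mean()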