From ecdda006deaf1e6b9dbeca091d5f684f83cb5631 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 5 Aug 2024 13:12:05 -0400
Subject: [PATCH] One cycle lr (#1803)

* refactor one_cycle lr scheduler so it's reusable in more situations

* fix validation for lr_scheduler

* default to cosine anneal strategy

* one cycle lr expects cos
---
 .pre-commit-config.yaml                    |  2 +
 src/axolotl/core/trainer_builder.py        | 74 ++++++++-----------
 .../config/models/input/v0_4_1/__init__.py |  2 +-
 3 files changed, 35 insertions(+), 43 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6c5f205897..9f2ceac56e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,6 +8,8 @@ repos:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
   - repo: https://github.com/psf/black
     rev: 23.3.0
     hooks:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index cf2866d81d..4e8b369052 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
             "help": "workaround to pass an alternate optimizer to the HF trainer"
         },
     )
+    alternate_lr_scheduler_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+        },
+    )
 
 
 @dataclass
@@ -318,7 +324,23 @@ def create_scheduler(
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
         # fmt: on
-            if use_cosine_quadratic:
+            if self.args.alternate_lr_scheduler_type == "one_cycle":
+                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+                pct_start = num_warmup_steps / num_training_steps
+                extra_lr_kwargs = {}
+                if "pct_start" not in self.args.lr_scheduler_kwargs:
+                    extra_lr_kwargs["pct_start"] = pct_start
+                if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
+                    extra_lr_kwargs["anneal_strategy"] = "cos"
+
+                self.lr_scheduler = OneCycleLR(
+                    optimizer,
+                    max_lr=self.args.learning_rate,
+                    total_steps=num_training_steps,
+                    **extra_lr_kwargs,
+                    **self.args.lr_scheduler_kwargs,
+                )
+            elif use_cosine_quadratic:
                 if use_cosine_min_lr:
                     LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -876,37 +898,6 @@ def compute_loss(
         return lm_loss
 
 
-class OneCycleLRSchedulerTrainer(AxolotlTrainer):
-    """
-    Trainer subclass that uses the OneCycleLR scheduler
-    """
-
-    tag_names = ["axolotl", "onecycle"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: Optional[torch.optim.Optimizer] = None,
-    ):
-        optimizer = self.optimizer if optimizer is None else optimizer
-        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        pct_start = num_warmup_steps / num_training_steps
-
-        self.lr_scheduler = OneCycleLR(
-            optimizer,
-            max_lr=self.args.learning_rate,
-            total_steps=num_training_steps,
-            pct_start=pct_start,
-            div_factor=6,
-        )
-
-        return self.lr_scheduler
-
-
 class ReLoRATrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ def get_post_trainer_create_callbacks(self, trainer):
         return callbacks
 
     def _get_trainer_cls(self):
-        if self.cfg.lr_scheduler == "one_cycle" and (
-            self.cfg.fsdp or self.cfg.adapter == "qlora"
-        ):
-            return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
@@ -1443,12 +1430,15 @@ def build(self, total_num_steps):
         training_arguments_kwargs[
             "loraplus_lr_embedding"
         ] = self.cfg.loraplus_lr_embedding
-        training_arguments_kwargs["lr_scheduler_type"] = (
-            self.cfg.lr_scheduler
-            if self.cfg.lr_scheduler
-            and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
-            else "cosine"
-        )
+        if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
+            training_arguments_kwargs["lr_scheduler_type"] = "cosine"
+            training_arguments_kwargs[
+                "alternate_lr_scheduler_type"
+            ] = self.cfg.lr_scheduler
+        else:
+            training_arguments_kwargs["lr_scheduler_type"] = (
+                self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+            )
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3b9dbb1a1c..4fb020bd51 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -378,7 +378,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = "cosine"
+    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
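
For illustration, a minimal standalone sketch (not part of the patch) of the scheduling
behavior the new alternate_lr_scheduler_type == "one_cycle" branch wires up: pct_start
defaults to the warmup ratio and anneal_strategy defaults to "cos", while anything supplied
via lr_scheduler_kwargs takes precedence. The toy model, optimizer, and step counts below
are placeholders, not axolotl code.

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    # Stand-ins for the trainer's model/optimizer and TrainingArguments values.
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    num_training_steps = 1000
    num_warmup_steps = 100        # stands in for args.get_warmup_steps(num_training_steps)
    lr_scheduler_kwargs = {}      # stands in for the user's lr_scheduler_kwargs

    # Mirror the patch: derive pct_start from the warmup ratio and anneal with cosine,
    # but let any explicit lr_scheduler_kwargs override either default.
    extra_lr_kwargs = {}
    if "pct_start" not in lr_scheduler_kwargs:
        extra_lr_kwargs["pct_start"] = num_warmup_steps / num_training_steps
    if "anneal_strategy" not in lr_scheduler_kwargs:
        extra_lr_kwargs["anneal_strategy"] = "cos"

    scheduler = OneCycleLR(
        optimizer,
        max_lr=2e-5,                  # stands in for args.learning_rate
        total_steps=num_training_steps,
        **extra_lr_kwargs,
        **lr_scheduler_kwargs,
    )

    # One scheduler step per optimizer step, as the HF Trainer does.
    for _ in range(num_training_steps):
        optimizer.step()
        scheduler.step()

Unlike the removed OneCycleLRSchedulerTrainer, the new branch does not pin div_factor=6;
OneCycleLR's own defaults apply unless overridden through lr_scheduler_kwargs.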