From 04b978b4283636a88ce6e54c4ed78b441b381f56 Mon Sep 17 00:00:00 2001
From: Ricardo Dominguez-Olmedo
Date: Tue, 9 Jan 2024 15:29:56 +0100
Subject: [PATCH] Cosine learning rate schedule - minimum learning rate (#1062)

* Cosine min lr

* Cosine min lr - warn if using deepspeed

* cosine_min_lr_ratio readme

* chore: lint

---------

Co-authored-by: Wing Lian
---
 README.md                           |  1 +
 src/axolotl/core/trainer_builder.py | 21 ++++++++++++++-
 src/axolotl/utils/schedulers.py     | 40 +++++++++++++++++++++++++++++
 3 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 30465e316f..7e53da2e1e 100644
--- a/README.md
+++ b/README.md
@@ -755,6 +755,7 @@ early_stopping_patience: 3
 # Specify a scheduler and kwargs to use with the optimizer
 lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
 lr_scheduler_kwargs:
+cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
 
 # For one_cycle optim
 lr_div_factor: # Learning rate div factor
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index cc3c73c9f8..f0d1c4343b 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -38,7 +38,10 @@
     MambaDataCollator,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
-from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
+from axolotl.utils.schedulers import (
+    get_cosine_schedule_with_min_lr,
+    get_cosine_schedule_with_quadratic_warmup,
+)
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -120,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "prefetch_factor argument to the dataloader"},
     )
+    cosine_min_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -159,6 +166,17 @@ def create_scheduler(
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
             )
+        elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
+            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
+            if self.args.deepspeed:
+                LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
+                    in the deepspeed JSON")
+            self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
+                optimizer,
+                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
+                min_lr_ratio=self.args.cosine_min_lr_ratio,
+            )
         else:
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler
@@ -745,6 +763,7 @@ def build(self, total_num_steps):
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
+        training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index 4c14a358a3..c49745c263 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -100,3 +100,43 @@ def get_cosine_schedule_with_quadratic_warmup(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_cosine_schedule_with_min_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float
+):
+    # Warm up
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    # Cosine learning rate decay
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    scaling = 0.5 * (1.0 + math.cos(math.pi * progress))
+    return (1 - min_lr_ratio) * scaling + min_lr_ratio
+
+
+def get_cosine_schedule_with_min_lr(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float = 0.0,
+):
+    """
+    Create a learning rate schedule which has:
+        - linear warmup from 0 -> `max_lr` over `num_warmup_steps`
+        - cosine learning rate annealing from `max_lr` -> `min_lr` over `num_training_steps`
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_min_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        min_lr_ratio=min_lr_ratio,
+    )
+    return LambdaLR(optimizer, lr_lambda)
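
Usage sketch (not part of the patch): a minimal, self-contained check of the new
get_cosine_schedule_with_min_lr added above. The dummy model, optimizer, learning
rate, and step counts below are illustrative assumptions, not values from the patch.

    # Illustrative sketch only -- assumes axolotl is installed so the import
    # below resolves to the function added in src/axolotl/utils/schedulers.py.
    import torch

    from axolotl.utils.schedulers import get_cosine_schedule_with_min_lr

    model = torch.nn.Linear(4, 4)  # dummy model for demonstration
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

    # Warm up for 10 steps, then cosine-anneal from 2e-4 down to 0.1 * 2e-4 = 2e-5.
    scheduler = get_cosine_schedule_with_min_lr(
        optimizer,
        num_warmup_steps=10,
        num_training_steps=100,
        min_lr_ratio=0.1,
    )

    for _ in range(100):
        optimizer.step()
        scheduler.step()

    print(scheduler.get_last_lr())  # ~[2e-5], the floor set by min_lr_ratio

Within axolotl itself, the same schedule is selected by the create_scheduler branch
above whenever lr_scheduler is cosine and cosine_min_lr_ratio is set in the config.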