diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 176ce4174..4ba2d880d 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -78,6 +78,7 @@
 from axolotl.utils.models import ensure_dtype
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
+    JaggedLRRestartScheduler,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
     get_cosine_schedule_with_warmup_decay_constant,
@@ -194,6 +195,22 @@ class AxolotlTrainingMixins:
         default=0.9,
         metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restarts_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restarts_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -415,6 +432,22 @@ def create_scheduler(
             else:
                 return super().create_scheduler(num_training_steps, optimizer=optimizer)
         else:
+            if self.args.jagged_restart_steps:
+                warmup_steps = (
+                    self.args.jagged_restarts_warmup_steps or 10
+                )
+                anneal_steps = (
+                    self.args.jagged_restarts_anneal_steps or 1
+                )
+                super().create_scheduler(num_training_steps, optimizer)
+                self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
+                    optimizer,
+                    self.lr_scheduler,
+                    self.args.jagged_restart_steps,
+                    warmup_steps,
+                    anneal_steps,
+                )
+
             if use_cosine_quadratic:
                 LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
 
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index 94387e5ab..51c1dd504 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -217,3 +217,52 @@ def get_cosine_schedule_with_warmup_decay_constant(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-LoRA-restart learning rate warmups."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restarts_steps: int,
+        jagged_restarts_warmup_steps: int,
+        jagged_restarts_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        # pylint: disable=duplicate-code
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restarts_steps
+        self.warmup_steps = jagged_restarts_warmup_steps
+        self.anneal_steps = jagged_restarts_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> list[float]:
+        # keep the wrapped schedule in sync with this scheduler's step count
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps:
+            # before the first restart, defer entirely to the inner schedule
+            scale = 1.0
+        else:
+            per_restart_progress = step % self.restarts_steps
+            if per_restart_progress < self.warmup_steps:
+                # linear warmup back to the inner schedule's LR after a reset
+                cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
+            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
+                # linear anneal toward the LR floor ahead of the next reset
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1.0
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        # get_lr returns one LR per param group; scale each of them
+        return [lr * scale for lr in original]
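
To see the schedule shape this produces, here is a minimal, self-contained sketch. It assumes the patch above is applied; the toy parameter, optimizer, and hyperparameter values are illustrative only, not part of the patch. The inner schedule is deliberately flat so that every learning-rate movement comes from the wrapper:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler

# toy parameter and optimizer, purely for illustration
param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.1)

# a flat inner schedule: any LR movement below comes from the wrapper
inner = LambdaLR(optimizer, lr_lambda=lambda _: 1.0)

scheduler = JaggedLRRestartScheduler(
    optimizer,
    inner,
    jagged_restarts_steps=100,        # reset boundary every 100 steps
    jagged_restarts_warmup_steps=10,  # ramp back up over 10 steps after a reset
    jagged_restarts_anneal_steps=5,   # wind down over the last 5 steps of a cycle
)

for step in range(300):
    optimizer.step()
    scheduler.step()
    # the first cycle stays at the inner LR (0.1); after that, each 100-step
    # cycle warms up from ~0.1 * min_lr_scale to 0.1, holds, then anneals
    # back toward the floor just before the next reset
    print(step, scheduler.get_last_lr())
```

Because `min_lr_scale` defaults to `0.001`, the learning rate never falls fully to zero at a reset boundary; the `scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale` line pins the bottom of each sawtooth at that floor.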