jagged lr restart scheduler #1680

Open · wants to merge 4 commits into main
src/axolotl/core/trainer_builder.py: 33 additions & 0 deletions
@@ -75,6 +75,7 @@
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
    JaggedLRRestartScheduler,
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
    get_cosine_schedule_with_warmup_decay_constant,
@@ -191,6 +192,22 @@ class AxolotlTrainingMixins:
        default=0.9,
        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
    )
    jagged_restart_steps: Optional[int] = field(
        default=None,
        metadata={"help": "number of steps between restarts for the jagged restart scheduler"},
    )
    jagged_restarts_warmup_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "number of warmup steps to take after each restart"
        },
    )
    jagged_restarts_anneal_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "number of anneal steps to take before each restart"
        },
    )
    bench_split: Optional[str] = field(
        default="eval", metadata={"help": "The benchmark split to run on"}
    )
@@ -412,6 +429,22 @@ def create_scheduler(
            else:
                return super().create_scheduler(num_training_steps, optimizer=optimizer)
        else:
            if self.args.jagged_restart_steps:
                warmup_steps = self.args.jagged_restarts_warmup_steps or 10
                anneal_steps = self.args.jagged_restarts_anneal_steps or 1
                super().create_scheduler(num_training_steps, optimizer)
                self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    self.lr_scheduler,
                    self.args.jagged_restart_steps,
                    warmup_steps,
                    anneal_steps,
                )

            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

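For context, here is a rough standalone sketch of what the new branch in create_scheduler does: build the usual Hugging Face scheduler, then wrap it with the jagged restart wrapper using the same fallback defaults (warmup 10, anneal 1). The model, optimizer, and step counts below are illustrative stand-ins, not values axolotl configures.

# Sketch only (not part of the diff): mirrors the wiring in create_scheduler.
import torch
from transformers import get_cosine_schedule_with_warmup

from axolotl.utils.schedulers import JaggedLRRestartScheduler

model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

# Stand-in for whatever super().create_scheduler() would normally build.
inner = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=10_000
)

scheduler = JaggedLRRestartScheduler(
    optimizer,
    inner,
    jagged_restarts_steps=1_000,      # self.args.jagged_restart_steps
    jagged_restarts_warmup_steps=10,  # fallback default from the branch above
    jagged_restarts_anneal_steps=1,   # fallback default from the branch above
)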
src/axolotl/utils/schedulers.py: 46 additions & 0 deletions
@@ -217,3 +217,49 @@ def get_cosine_schedule_with_warmup_decay_constant(
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


class JaggedLRRestartScheduler(LRScheduler):
    """Wraps another scheduler to apply per-restart learning rate warmup and anneal."""

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        jagged_restarts_steps: int,
        jagged_restarts_warmup_steps: int,
        jagged_restarts_anneal_steps: int = 1,
        min_lr_scale: float = 0.001,
    ) -> None:
        # pylint: disable=duplicate-code
        self.inner_schedule = inner_schedule
        self.restarts_steps = jagged_restarts_steps
        self.warmup_steps = jagged_restarts_warmup_steps
        self.anneal_steps = jagged_restarts_anneal_steps
        self.min_lr_scale = min_lr_scale
        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

    def get_lr(self) -> list[float]:
        # keep the wrapped schedule in lockstep with this scheduler's step counter
        self.inner_schedule.last_epoch = self.last_epoch

        original = self.inner_schedule.get_lr()
        step = self.last_epoch

        if step < self.restarts_steps:
            # first cycle: defer entirely to the inner schedule
            scale = 1.0
        else:
            per_restart_progress = step % self.restarts_steps
            if per_restart_progress < self.warmup_steps:
                # ramp back up after a restart
                cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
                # anneal down ahead of the next restart
                cycle_t = min(
                    1.0,
                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
                )
            else:
                cycle_t = 1.0
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        return [lr * scale for lr in original]
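To illustrate the shape this produces, a minimal sketch (assuming a plain SGD optimizer and a constant inner schedule, neither of which is part of this PR): after the first cycle, the learning rate drops to roughly min_lr_scale at each restart, re-warms over the warmup steps, holds at the inner schedule's value, then anneals back down just before the next restart.

# Sketch: print the scaled learning rate over a few restart cycles.
import torch
from torch.optim.lr_scheduler import LambdaLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
inner = LambdaLR(opt, lambda step: 1.0)  # constant base LR of 1.0

sched = JaggedLRRestartScheduler(
    opt,
    inner,
    jagged_restarts_steps=20,
    jagged_restarts_warmup_steps=5,
    jagged_restarts_anneal_steps=5,
)

for step in range(60):
    opt.step()
    sched.step()
    print(step, sched.get_last_lr())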