Commit 8430db2

Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273)
jinwonkim93 authored Feb 13, 2024
1 parent 4b997c3 commit 8430db2
Showing 4 changed files with 152 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -813,6 +813,7 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some fraction of the training steps, e.g. cosine_constant_lr_ratio=0.8 means the lr is held at cosine_min_lr from 80% of the training steps onward (https://arxiv.org/pdf/2308.04014.pdf)

# For one_cycle optim
lr_div_factor: # Learning rate div factor
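As a rough illustration of how the two ratios above interact (the numbers below are hypothetical and not part of this commit): with a peak learning rate of 2e-5 over 1000 training steps, cosine_min_lr_ratio: 0.1 and cosine_constant_lr_ratio: 0.8 mean the cosine decay bottoms out at 2e-6 and the lr is held there from step 800 onward. A minimal sketch:

# Hypothetical values chosen only to illustrate the two options documented above.
peak_lr = 2e-5                  # learning_rate from the config
total_steps = 1000              # total training steps
cosine_min_lr_ratio = 0.1       # decay to 10% of the peak lr
cosine_constant_lr_ratio = 0.8  # hold the floor from 80% of training onward

floor_lr = peak_lr * cosine_min_lr_ratio                     # 2e-06
constant_from = int(total_steps * cosine_constant_lr_ratio)  # step 800
print(f"cosine decay until step {constant_from}, then lr held at {floor_lr:.0e}")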
20 changes: 20 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -50,6 +50,7 @@
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)

try:
@@ -164,6 +165,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)


class AxolotlTrainer(Trainer):
@@ -221,6 +228,16 @@ def create_scheduler(
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -887,6 +904,9 @@ def build(self, total_num_steps):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
81 changes: 79 additions & 2 deletions src/axolotl/utils/schedulers.py
@@ -52,7 +52,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float
num_cycles: float,
):
if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +107,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
*,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float
min_lr_ratio: float,
):
# Warm up
if current_step < num_warmup_steps:
@@ -140,3 +140,80 @@ def get_cosine_schedule_with_min_lr(
min_lr_ratio=min_lr_ratio,
)
return LambdaLR(optimizer, lr_lambda)


def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))

num_constant_steps = int(num_training_steps * constant_lr_ratio)
current_step = min(current_step, num_constant_steps)

progress = float(current_step - num_warmup_steps) / float(
max(1, num_constant_steps - num_warmup_steps)
)

return (
max(
0,
(1 - min_lr_ratio)
* 0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
+ min_lr_ratio
)


def get_cosine_schedule_with_warmup_decay_constant(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
Create a schedule with a learning rate that, after a warmup period during which it increases linearly from 0 to the
initial lr set in the optimizer, decreases following the values of the cosine function from the initial lr down to
min_lr_ratio * initial lr until num_training_steps * constant_lr_ratio steps, then stays constant at that minimum for the remaining steps.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
constant_lr_ratio (`float`):
The fraction of num_training_steps over which the cosine decay runs; the learning rate is held constant after num_training_steps * constant_lr_ratio steps.
min_lr_ratio (`float`):
The ratio of the maximum learning rate that the cosine function decays to; the constant phase holds the learning rate at min_lr_ratio * initial lr.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

lr_lambda = partial(
_get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
constant_lr_ratio=constant_lr_ratio,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
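A minimal usage sketch of the new scheduler (toy hyperparameters, mirroring the test added below), printing the learning rate at a few steps to show the warmup, cosine-decay, and constant phases:

import torch
from torch.optim import SGD

from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant

# Toy parameter and optimizer; only the schedule itself is of interest here.
optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
scheduler = get_cosine_schedule_with_warmup_decay_constant(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=1000,
    constant_lr_ratio=0.8,  # cosine decay ends at step 800
    min_lr_ratio=0.1,       # lr floor is 10% of the peak lr
)

for step in range(1000):
    if step in (0, 10, 400, 800, 999):
        print(step, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
# Expected shape: 0.0 at step 0 (start of warmup), 0.01 at step 10 (peak),
# ~0.0056 at step 400 (mid-cosine), and 0.001 from step 800 onward (constant floor).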
52 changes: 52 additions & 0 deletions tests/test_schedulers.py
@@ -0,0 +1,52 @@
"""
test module for the axolotl.utis.data module
"""
import unittest

import torch
from torch.optim import SGD

from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant


class TestCosineConstantLr(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""

def setUp(self):
self.train_steps = 1000
self.warmup_steps = 10
self.min_lr_ratio = 0.1
self.constant_lr_ratio = 0.8
self._lr = 0.01
self.optimizer = SGD([torch.tensor(1)], lr=self._lr)
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
self.optimizer,
num_warmup_steps=self.warmup_steps,
num_training_steps=self.train_steps,
min_lr_ratio=self.min_lr_ratio,
constant_lr_ratio=self.constant_lr_ratio,
)

def test_schedulers(self):
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
for _ in range(self.warmup_steps):
self.lr_scheduler.step()
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
constant_step = int(self.train_steps * self.constant_lr_ratio)
remaining_step = self.train_steps - constant_step
for _ in range(constant_step):
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)
for _ in range(remaining_step):
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)


if __name__ == "__main__":
unittest.main()
