Inverse Square Root LR Schedule (#657)
* Implement inverse square root with warmup scheduler v0

* inverse square root LR Schedule

* scheduler

* unit tests

* Update llmfoundry/optim/scheduler.py

Co-authored-by: Brian <[email protected]>

* Update llmfoundry/optim/scheduler.py

Co-authored-by: Brian <[email protected]>

* fixes for PR conversations

* format

* fix type hint

---------

Co-authored-by: Brian <[email protected]>
Co-authored-by: cody <[email protected]>
3 people authored Oct 11, 2023
1 parent bdac4c7 commit 6c98276
Showing 3 changed files with 269 additions and 0 deletions.
153 changes: 153 additions & 0 deletions llmfoundry/optim/scheduler.py
@@ -0,0 +1,153 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Experimental learning rate schedulers used for training LLMs."""

import textwrap
import warnings
from typing import Union

from composer.core import State, Time, TimeUnit
from composer.optim import ComposerScheduler, LinearScheduler
from composer.optim.scheduler import _convert_time

__all__ = ['InverseSquareRootWithWarmupScheduler']


def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
                               name: str) -> None:
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if isinstance(t_max, str):
        t_max = Time.from_timestring(t_max)
    if time.unit != t_max.unit:
        raise ValueError(f'{time.unit=} does not match {t_max.unit=}.')


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if time.unit == TimeUnit('dur'):
        raise ValueError(f'{name} cannot be in units of "dur".')


class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
    r"""Inverse square root LR decay with warmup and optional linear cooldown.

    Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as:

    .. math::
        \alpha(t) = \begin{cases}
            t / t_{warmup}, & \text{if } t < t_{warmup} \\
            \alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} <= t < t_{max} - t_{cooldown} \\
            \alpha_i + (\alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise}
        \end{cases}

    Here :math:`\tau_d` is the time elapsed during the inverse square root decay (normalized by :math:`t_{scale}`):

    .. math::
        \tau_d = (t - t_{warmup} + t_{scale}) / t_{scale}

    :math:`\alpha_i` is the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`,
    and :math:`\tau_c` is the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`):

    .. math::
        \tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown}

    :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale,
    :math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler,
    :math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decay approaches at infinite time,
    and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to.

    Note that :math:`\alpha_{f,decay} >= \alpha_{f,cooldown}` is required to ensure that the learning rate is monotonically decreasing after warmup.

    Also note that ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration; since this schedule is designed for continual learning,
    ``max_duration`` is expected to change. Instead, these parameters need to be specified in the same units as the ``max_duration`` passed to the trainer.

    Args:
        t_warmup (str | Time): The warmup time.
        t_scale (str | Time): The time scale.
        t_cooldown (str | Time): The cooldown time.
        t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
        alpha_f_decay (float): The learning rate multiplier that the inverse square root decay approaches at infinite time. Default = ``0.0``.
        alpha_f_cooldown (float): The learning rate multiplier that the linear cooldown decays to. Default = ``0.0``.
    """

    def __init__(self,
                 t_warmup: Union[str, Time],
                 t_scale: Union[str, Time],
                 t_cooldown: Union[str, Time],
                 t_max: Union[str, Time] = '1dur',
                 alpha_f_decay: float = 0.0,
                 alpha_f_cooldown: float = 0.0) -> None:
        if alpha_f_decay < alpha_f_cooldown:
            raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. '
                              f'Current: alpha_f_decay={alpha_f_decay}, '
                              f'alpha_f_cooldown={alpha_f_cooldown}.'))
        _raise_if_units_dur(t_warmup, 't_warmup')
        _raise_if_units_dur(t_scale, 't_scale')
        _raise_if_units_dur(t_cooldown, 't_cooldown')
        self.t_warmup = t_warmup
        self.t_scale = t_scale
        self.t_cooldown = t_cooldown
        self.t_max = t_max
        self.alpha_f_decay = alpha_f_decay
        self.alpha_f_cooldown = alpha_f_cooldown
        self.warmup_scheduler = LinearScheduler(alpha_i=0.0,
                                                alpha_f=1.0,
                                                t_max=t_warmup)

    def __call__(self, state: State, ssr: float = 1.0) -> float:
        assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
        _raise_if_units_dont_match(self.t_warmup, state.max_duration,
                                   't_warmup')
        _raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale')
        _raise_if_units_dont_match(self.t_cooldown, state.max_duration,
                                   't_cooldown')

        t_warmup = _convert_time(self.t_warmup, state)
        if t_warmup.value == 0:
            warnings.warn(
                textwrap.dedent("""\
                The warmup duration is 0. If warmup was specified as a fraction of the total
                training duration, the warmup duration is calculated in the
                same unit as the trainer's max_duration parameter."""))

        if state.timestamp < t_warmup:
            return self.warmup_scheduler(state)

        t_scale = _convert_time(self.t_scale, state, ssr=ssr)
        t_cooldown = _convert_time(self.t_cooldown, state, ssr=ssr)
        t_max = _convert_time(self.t_max, state, ssr=ssr)
        current_time = state.timestamp.get(t_scale.unit)

        t_shift = t_scale - t_warmup
        # t_cooldown_start is max of t_warmup, t_max - t_cooldown
        t_cooldown_start = t_max - t_cooldown
        if t_cooldown_start < t_warmup:
            t_cooldown_start = t_warmup

        if state.timestamp < t_cooldown_start:
            # Rescale LR by a coefficient equal to the inverse square root of the time
            # elapsed after warmup, rescaled by the time scale, such that, at
            # infinite time, the LR decays to alpha_f_decay.
            coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5
            current_factor = (self.alpha_f_decay + coeff *
                              (1.0 - self.alpha_f_decay))
            return current_factor

        else:
            coeff = 1 / ((t_cooldown_start + t_shift) / t_scale).value**0.5
            alpha_i = self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay)

            if t_cooldown.value == 0:
                return alpha_i

            # Linearly decay the LR from its value at the step at which cooldown
            # started to alpha_f_cooldown over t_cooldown time.
            frac_of_cooldown = ((current_time - t_cooldown_start) /
                                t_cooldown).value
            frac_of_cooldown = min(1.0, frac_of_cooldown)
            current_factor = (alpha_i + frac_of_cooldown *
                              (self.alpha_f_cooldown - alpha_i))
            return current_factor
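
For orientation (an editorial sketch, not part of this diff): the new scheduler is constructed directly and handed to a Composer Trainer. The model, dataloader, and optimizer below are assumed to be built elsewhere; only the scheduler class and its arguments come from this commit, and t_warmup, t_scale, and t_cooldown must share the unit of max_duration (batches here).

from composer import Trainer

from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler

# Warm up for 1000 batches, decay on a 1000-batch time scale toward a floor of
# 0.1x the peak LR, then linearly cool down to 0 over the final 10000 batches.
scheduler = InverseSquareRootWithWarmupScheduler(
    t_warmup='1000ba',
    t_scale='1000ba',
    t_cooldown='10000ba',
    alpha_f_decay=0.1,
    alpha_f_cooldown=0.0)

trainer = Trainer(
    model=model,  # assumed: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # assumed: defined elsewhere
    optimizers=optimizer,  # assumed: defined elsewhere
    schedulers=scheduler,
    max_duration='100000ba',
)
trainer.fit()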
3 changes: 3 additions & 0 deletions llmfoundry/utils/builders.py
@@ -32,6 +32,7 @@
                                  ScheduledGarbageCollector)
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
                              DecoupledLionW, DecoupledLionW_8bit)
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper

log = logging.getLogger(__name__)
@@ -158,6 +159,8 @@ def build_scheduler(name: str,
        return ConstantWithWarmupScheduler(**scheduler_config)
    elif name == 'cosine_with_warmup':
        return CosineAnnealingWithWarmupScheduler(**scheduler_config)
    elif name == 'inv_sqrt_with_warmup':
        return InverseSquareRootWithWarmupScheduler(**scheduler_config)
    elif name == 'linear_decay_with_warmup':
        return LinearWithWarmupScheduler(**scheduler_config)
    else:
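
The builders.py change only registers the new name, so here is a short illustrative sketch of exercising that hook (the parameter values are made up; the positional (name, scheduler_config) call follows the build_scheduler signature shown in the hunk above):

from llmfoundry.utils.builders import build_scheduler

# 'inv_sqrt_with_warmup' is the name added in this commit; the dict is expanded
# into keyword arguments for InverseSquareRootWithWarmupScheduler.
scheduler = build_scheduler(
    'inv_sqrt_with_warmup', {
        't_warmup': '1000ba',
        't_scale': '1000ba',
        't_cooldown': '10000ba',
        'alpha_f_decay': 0.1,
        'alpha_f_cooldown': 0.0,
    })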
113 changes: 113 additions & 0 deletions tests/test_scheduler.py
@@ -0,0 +1,113 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List

import pytest
import torch
from composer.core import State, Time, TimeUnit
from composer.devices import DeviceCPU, DeviceGPU
from composer.optim.scheduler import ComposerScheduler

from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler

_MAX_DURATION = '100ba'
_STEPS_PER_EPOCH = 100


@pytest.fixture
def dummy_schedulers_state(request: pytest.FixtureRequest):
    device = None
    for item in request.session.items:
        device = DeviceCPU(
        ) if item.get_closest_marker('gpu') is None else DeviceGPU()
        break
    assert device != None
    state = State(
        model=torch.nn.Linear(5, 5),
        run_name='run_name',
        device=device,
        rank_zero_seed=17,
        max_duration=_MAX_DURATION,
    )
    state.set_dataloader([None] * _STEPS_PER_EPOCH, 'train')
    return state


@pytest.mark.parametrize('scheduler,ssr,test_times,expected_lrs', [
    pytest.param(
        InverseSquareRootWithWarmupScheduler(t_warmup='10ba',
                                             t_scale='10ba',
                                             t_cooldown='0ba',
                                             alpha_f_decay=0,
                                             alpha_f_cooldown=0), 1.0,
        ['0ba', '5ba', '10ba', '40ba', '90ba', '100ba'],
        [0.0, 0.5, 1.0, 0.5, 0.33333, 0.31623]),
    pytest.param(
        InverseSquareRootWithWarmupScheduler(t_warmup='20ba',
                                             t_scale='2ba',
                                             t_cooldown='10ba',
                                             alpha_f_decay=0.4,
                                             alpha_f_cooldown=0.1), 1.0,
        ['0ba', '10ba', '20ba', '36ba', '90ba', '95ba', '100ba'],
        [0.0, 0.5, 1.0, 0.6, 0.5, 0.3, 0.1]),
])
def test_scheduler_init(scheduler: ComposerScheduler, ssr: float,
                        test_times: List[str], expected_lrs: List[float],
                        dummy_schedulers_state: State):

    state = dummy_schedulers_state
    assert state.dataloader_len is not None
    assert state.max_duration is not None
    state.max_duration = Time(value=int(state.max_duration.value * ssr),
                              unit=state.max_duration.unit)
    for test_time, expected_lr in zip(test_times, expected_lrs):
        parsed_time = Time.from_timestring(test_time)
        assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH]
        state.timestamp = state.timestamp.copy(
            batch=parsed_time,
            epoch=Time(
                int(parsed_time) // int(state.dataloader_len), TimeUnit.EPOCH),
        )
        lr = scheduler(state, ssr)
        assert lr == pytest.approx(expected_lr, abs=1e-3)


@pytest.mark.parametrize('state_unit,warmup_unit,scale_unit,cooldown_unit', [
    ['ep', 'ba', 'ba', 'ba'],
    ['ba', 'ep', 'ep', 'ep'],
    ['ep', 'ep', 'ba', 'ep'],
])
def test_scheduler_units_match_error(state_unit: str, warmup_unit: str,
                                     scale_unit: str, cooldown_unit: str,
                                     dummy_schedulers_state: State):

    state = dummy_schedulers_state
    state.max_duration = f'1{state_unit}'
    scheduler = InverseSquareRootWithWarmupScheduler(
        t_warmup=f'10{warmup_unit}',
        t_scale=f'10{scale_unit}',
        t_cooldown=f'10{cooldown_unit}')
    with pytest.raises(ValueError, match='does not match'):
        _ = scheduler(state, 1.0)


@pytest.mark.parametrize('warmup_unit,scale_unit,cooldown_unit', [
    ['dur', 'ba', 'ba'],
    ['ba', 'dur', 'ba'],
    ['ba', 'ba', 'dur'],
])
def test_unit_dur_error(warmup_unit: str, scale_unit: str, cooldown_unit: str):
    with pytest.raises(ValueError, match='cannot be in units of "dur".'):
        _ = InverseSquareRootWithWarmupScheduler(t_warmup=f'1{warmup_unit}',
                                                 t_scale=f'1{scale_unit}',
                                                 t_cooldown=f'1{cooldown_unit}')


def test_alpha_f_error():
    with pytest.raises(ValueError, match='alpha_f_decay >= alpha_f_cooldown.'):
        _ = InverseSquareRootWithWarmupScheduler(t_warmup='10ba',
                                                 t_scale='10ba',
                                                 t_cooldown='10ba',
                                                 alpha_f_decay=0.0,
                                                 alpha_f_cooldown=0.1)
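
As a sanity check on the first parametrization above (t_warmup='10ba', t_scale='10ba', t_cooldown='0ba', alpha_f_decay=0), the post-warmup multiplier is 1 / sqrt(tau_d) with tau_d = (t - t_warmup + t_scale) / t_scale, which reproduces the expected values. A standalone spot-check (editorial, not part of the test file):

# tau_d at 40, 90, and 100 batches is 4, 9, and 10, so the multiplier is
# 1/sqrt(4) = 0.5, 1/sqrt(9) = 0.33333..., and 1/sqrt(10) = 0.31623...
for t, expected in [(40, 0.5), (90, 0.33333), (100, 0.31623)]:
    tau_d = (t - 10 + 10) / 10
    assert abs(1 / tau_d**0.5 - expected) < 1e-3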
