-
Notifications
You must be signed in to change notification settings - Fork 534
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Inverse Square Root LR Schedule (#657)
* Implement inverse square root with warmup scheduler v0 * inverse square root LR Schedule * scheduler * unit tests * Update llmfoundry/optim/scheduler.py Co-authored-by: Brian <[email protected]> * Update llmfoundry/optim/scheduler.py Co-authored-by: Brian <[email protected]> * fixes for PR conversations * format * fix type hint --------- Co-authored-by: Brian <[email protected]> Co-authored-by: cody <[email protected]>
- Loading branch information
1 parent
bdac4c7
commit 6c98276
Showing
3 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2022 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Experimental learning rate schedulers used for training LLMs.""" | ||
|
||
import textwrap | ||
import warnings | ||
from typing import Union | ||
|
||
from composer.core import State, Time, TimeUnit | ||
from composer.optim import ComposerScheduler, LinearScheduler | ||
from composer.optim.scheduler import _convert_time | ||
|
||
__all__ = ['InverseSquareRootWithWarmupScheduler'] | ||
|
||
|
||
def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], | ||
name: str) -> None: | ||
if isinstance(time, str): | ||
time = Time.from_timestring(time) | ||
if isinstance(t_max, str): | ||
t_max = Time.from_timestring(t_max) | ||
if time.unit != t_max.unit: | ||
raise ValueError(f'{time.unit=} does not match {t_max.unit=}.') | ||
|
||
|
||
def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: | ||
if isinstance(time, str): | ||
time = Time.from_timestring(time) | ||
if time.unit == TimeUnit('dur'): | ||
raise ValueError(f'{name} cannot be in units of "dur".') | ||
|
||
|
||
class InverseSquareRootWithWarmupScheduler(ComposerScheduler): | ||
r"""Inverse square root LR decay with warmup and optional linear cooldown. | ||
Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as: | ||
.. math:: | ||
\alpha(t) = \begin{cases} | ||
t / t_{warmup}, & \text{if } t < t_{warmup} \\ | ||
\alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} <= t < t_{max} - t_{cooldown} \\ | ||
\alpha_i + (alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise} | ||
\end{cases} | ||
Given :math:`\tau_d`, the time elapsed during the inverse square root decay (normalized by :math:`t_scale`), as: | ||
.. math:: | ||
\tau_d = (t - t_{warmup} + t_{scale}) / {t_scale} | ||
:math:`\alpha_i` as the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`, | ||
and :math:`\tau_c`, the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`), as: | ||
.. math:: | ||
\tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown} | ||
Where :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale, | ||
:math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler, | ||
:math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decays to at infinite time, | ||
and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to. | ||
Note, :math:`\alpha_{f,decay} >= \alpha_{f,cooldown}` to ensure that the learning rate is monotonically decreasing after warmup. | ||
Also note, ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration; since this schedule is designed for continual learning, | ||
``max_duration`` is expected to change. Instead, these parameters need to be specified in the same units as ``max_duration`` passed to the trainer. | ||
Args: | ||
t_warmup (str | Time): The warmup time. | ||
t_scale (str | Time): The time scale. | ||
t_cooldown (str | Time): The cooldown time. | ||
t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``. | ||
alpha_f_decay (float): The learning rate multiplier to decay inverse square root decay to. Default = ``0.0``. | ||
alpha_f_cooldown (float): The learning rate multiplier to decay linear cooldown to. Default = ``0.0``. | ||
""" | ||
|
||
def __init__(self, | ||
t_warmup: Union[str, Time], | ||
t_scale: Union[str, Time], | ||
t_cooldown: Union[str, Time], | ||
t_max: Union[str, Time] = '1dur', | ||
alpha_f_decay: float = 0.0, | ||
alpha_f_cooldown: float = 0.0) -> None: | ||
if alpha_f_decay < alpha_f_cooldown: | ||
raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. ' | ||
f'Current: alpha_f_decay={alpha_f_decay}, ' | ||
f'alpha_f_cooldown={alpha_f_cooldown}.')) | ||
_raise_if_units_dur(t_warmup, 't_warmup') | ||
_raise_if_units_dur(t_scale, 't_scale') | ||
_raise_if_units_dur(t_cooldown, 't_cooldown') | ||
self.t_warmup = t_warmup | ||
self.t_scale = t_scale | ||
self.t_cooldown = t_cooldown | ||
self.t_max = t_max | ||
self.alpha_f_decay = alpha_f_decay | ||
self.alpha_f_cooldown = alpha_f_cooldown | ||
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, | ||
alpha_f=1.0, | ||
t_max=t_warmup) | ||
|
||
def __call__(self, state: State, ssr: float = 1.0) -> float: | ||
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked' | ||
_raise_if_units_dont_match(self.t_warmup, state.max_duration, | ||
't_warmup') | ||
_raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale') | ||
_raise_if_units_dont_match(self.t_cooldown, state.max_duration, | ||
't_cooldown') | ||
|
||
t_warmup = _convert_time(self.t_warmup, state) | ||
if t_warmup.value == 0: | ||
warnings.warn( | ||
textwrap.dedent("""\ | ||
The warmup duration is 0. If warmup was specified as a fraction of the total | ||
training duration, the warmup duration is calculated in the | ||
same unit as the trainer's max_duration parameter.""")) | ||
|
||
if state.timestamp < t_warmup: | ||
return self.warmup_scheduler(state) | ||
|
||
t_scale = _convert_time(self.t_scale, state, ssr=ssr) | ||
t_cooldown = _convert_time(self.t_cooldown, state, ssr=ssr) | ||
t_max = _convert_time(self.t_max, state, ssr=ssr) | ||
current_time = state.timestamp.get(t_scale.unit) | ||
|
||
t_shift = t_scale - t_warmup | ||
# t_cooldown_start is max of t_warmup, t_max - t_cooldown | ||
t_cooldown_start = t_max - t_cooldown | ||
if t_cooldown_start < t_warmup: | ||
t_cooldown_start = t_warmup | ||
|
||
if state.timestamp < t_cooldown_start: | ||
# Rescale LR by a coefficient equal to the inverse square root of the time | ||
# elapsed after warmup, rescaled by the time scale, such that, at | ||
# infinite time, the LR decays to alpha_f_decay. | ||
coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5 | ||
current_factor = (self.alpha_f_decay + coeff * | ||
(1.0 - self.alpha_f_decay)) | ||
return current_factor | ||
|
||
else: | ||
coeff = 1 / ((t_cooldown_start + t_shift) / t_scale).value**0.5 | ||
alpha_i = self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay) | ||
|
||
if t_cooldown.value == 0: | ||
return alpha_i | ||
|
||
# Linearly decay the LR from its value at the step at which cooldown | ||
# started to alpha_f_cooldown over t_cooldown time. | ||
frac_of_cooldown = ((current_time - t_cooldown_start) / | ||
t_cooldown).value | ||
frac_of_cooldown = min(1.0, frac_of_cooldown) | ||
current_factor = (alpha_i + frac_of_cooldown * | ||
(self.alpha_f_cooldown - alpha_i)) | ||
return current_factor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright 2022 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import List | ||
|
||
import pytest | ||
import torch | ||
from composer.core import State, Time, TimeUnit | ||
from composer.devices import DeviceCPU, DeviceGPU | ||
from composer.optim.scheduler import ComposerScheduler | ||
|
||
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler | ||
|
||
_MAX_DURATION = '100ba' | ||
_STEPS_PER_EPOCH = 100 | ||
|
||
|
||
@pytest.fixture | ||
def dummy_schedulers_state(request: pytest.FixtureRequest): | ||
device = None | ||
for item in request.session.items: | ||
device = DeviceCPU( | ||
) if item.get_closest_marker('gpu') is None else DeviceGPU() | ||
break | ||
assert device != None | ||
state = State( | ||
model=torch.nn.Linear(5, 5), | ||
run_name='run_name', | ||
device=device, | ||
rank_zero_seed=17, | ||
max_duration=_MAX_DURATION, | ||
) | ||
state.set_dataloader([None] * _STEPS_PER_EPOCH, 'train') | ||
return state | ||
|
||
|
||
@pytest.mark.parametrize('scheduler,ssr,test_times,expected_lrs', [ | ||
pytest.param( | ||
InverseSquareRootWithWarmupScheduler(t_warmup='10ba', | ||
t_scale='10ba', | ||
t_cooldown='0ba', | ||
alpha_f_decay=0, | ||
alpha_f_cooldown=0), 1.0, | ||
['0ba', '5ba', '10ba', '40ba', '90ba', '100ba'], | ||
[0.0, 0.5, 1.0, 0.5, 0.33333, 0.31623]), | ||
pytest.param( | ||
InverseSquareRootWithWarmupScheduler(t_warmup='20ba', | ||
t_scale='2ba', | ||
t_cooldown='10ba', | ||
alpha_f_decay=0.4, | ||
alpha_f_cooldown=0.1), 1.0, | ||
['0ba', '10ba', '20ba', '36ba', '90ba', '95ba', '100ba'], | ||
[0.0, 0.5, 1.0, 0.6, 0.5, 0.3, 0.1]), | ||
]) | ||
def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, | ||
test_times: List[str], expected_lrs: List[float], | ||
dummy_schedulers_state: State): | ||
|
||
state = dummy_schedulers_state | ||
assert state.dataloader_len is not None | ||
assert state.max_duration is not None | ||
state.max_duration = Time(value=int(state.max_duration.value * ssr), | ||
unit=state.max_duration.unit) | ||
for test_time, expected_lr in zip(test_times, expected_lrs): | ||
parsed_time = Time.from_timestring(test_time) | ||
assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH] | ||
state.timestamp = state.timestamp.copy( | ||
batch=parsed_time, | ||
epoch=Time( | ||
int(parsed_time) // int(state.dataloader_len), TimeUnit.EPOCH), | ||
) | ||
lr = scheduler(state, ssr) | ||
assert lr == pytest.approx(expected_lr, abs=1e-3) | ||
|
||
|
||
@pytest.mark.parametrize('state_unit,warmup_unit,scale_unit,cooldown_unit', [ | ||
['ep', 'ba', 'ba', 'ba'], | ||
['ba', 'ep', 'ep', 'ep'], | ||
['ep', 'ep', 'ba', 'ep'], | ||
]) | ||
def test_scheduler_units_match_error(state_unit: str, warmup_unit: str, | ||
scale_unit: str, cooldown_unit: str, | ||
dummy_schedulers_state: State): | ||
|
||
state = dummy_schedulers_state | ||
state.max_duration = f'1{state_unit}' | ||
scheduler = InverseSquareRootWithWarmupScheduler( | ||
t_warmup=f'10{warmup_unit}', | ||
t_scale=f'10{scale_unit}', | ||
t_cooldown=f'10{cooldown_unit}') | ||
with pytest.raises(ValueError, match='does not match'): | ||
_ = scheduler(state, 1.0) | ||
|
||
|
||
@pytest.mark.parametrize('warmup_unit,scale_unit,cooldown_unit', [ | ||
['dur', 'ba', 'ba'], | ||
['ba', 'dur', 'ba'], | ||
['ba', 'ba', 'dur'], | ||
]) | ||
def test_unit_dur_error(warmup_unit: str, scale_unit: str, cooldown_unit: str): | ||
with pytest.raises(ValueError, match='cannot be in units of "dur".'): | ||
_ = InverseSquareRootWithWarmupScheduler(t_warmup=f'1{warmup_unit}', | ||
t_scale=f'1{scale_unit}', | ||
t_cooldown=f'1{cooldown_unit}') | ||
|
||
|
||
def test_alpha_f_error(): | ||
with pytest.raises(ValueError, match='alpha_f_decay >= alpha_f_cooldown.'): | ||
_ = InverseSquareRootWithWarmupScheduler(t_warmup='10ba', | ||
t_scale='10ba', | ||
t_cooldown='10ba', | ||
alpha_f_decay=0.0, | ||
alpha_f_cooldown=0.1) |