# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Experimental learning rate schedulers used for training LLMs."""

import textwrap
import warnings
from typing import Union

from composer.core import State, Time, TimeUnit
from composer.optim import ComposerScheduler, LinearScheduler
from composer.optim.scheduler import _convert_time

__all__ = ['InverseSquareRootWithWarmupScheduler']


def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
                               name: str) -> None:
    """Raise if ``time`` and ``t_max`` are expressed in different units.

    Args:
        time (str | Time): The value to check; timestrings are parsed first.
        t_max (str | Time): The reference value; timestrings are parsed first.
        name (str): Name of the parameter being checked, used in the error.

    Raises:
        ValueError: If the two units differ.
    """
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if isinstance(t_max, str):
        t_max = Time.from_timestring(t_max)
    if time.unit != t_max.unit:
        # Prefix the offending parameter name so the user knows which setting
        # to fix; the rest of the message is kept as-is.
        raise ValueError(
            f'{name}: {time.unit=} does not match {t_max.unit=}.')


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
    """Raise if ``time`` is expressed in units of duration (``"dur"``).

    Args:
        time (str | Time): The value to check; timestrings are parsed first.
        name (str): Name of the parameter being checked, used in the error.

    Raises:
        ValueError: If the unit of ``time`` is ``"dur"``.
    """
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if time.unit == TimeUnit('dur'):
        raise ValueError(f'{name} cannot be in units of "dur".')


class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
    r"""Inverse square root LR decay with warmup and optional linear cooldown.

    Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as:

    .. math::
        \alpha(t) = \begin{cases}
            t / t_{warmup}, & \text{if } t < t_{warmup} \\
            \alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} <= t < t_{max} - t_{cooldown} \\
            \alpha_i + (\alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise}
        \end{cases}

    Given :math:`\tau_d`, the time elapsed during the inverse square root decay (normalized by :math:`t_{scale}`), as:

    .. math::
        \tau_d = (t - t_{warmup} + t_{scale}) / t_{scale}

    :math:`\alpha_i` as the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`,
    and :math:`\tau_c`, the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`), as:

    .. math::
        \tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown}

    Where :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale,
    :math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler,
    :math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decays to at infinite time,
    and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to.

    Note, :math:`\alpha_{f,decay} >= \alpha_{f,cooldown}` to ensure that the learning rate is monotonically decreasing after warmup.

    Also note, ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration; since this schedule is designed for continual learning,
    ``max_duration`` is expected to change. Instead, these parameters need to be specified in the same units as ``max_duration`` passed to the trainer.

    Args:
        t_warmup (str | Time): The warmup time.
        t_scale (str | Time): The time scale. Must be strictly positive.
        t_cooldown (str | Time): The cooldown time.
        t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
        alpha_f_decay (float): The learning rate multiplier that the inverse square root decay approaches. Default = ``0.0``.
        alpha_f_cooldown (float): The learning rate multiplier that the linear cooldown decays to. Default = ``0.0``.

    Raises:
        ValueError: If ``alpha_f_decay < alpha_f_cooldown``, if ``t_warmup``,
            ``t_scale`` or ``t_cooldown`` is given in units of ``"dur"``, or
            if ``t_scale`` is not strictly positive.
    """

    def __init__(self,
                 t_warmup: Union[str, Time],
                 t_scale: Union[str, Time],
                 t_cooldown: Union[str, Time],
                 t_max: Union[str, Time] = '1dur',
                 alpha_f_decay: float = 0.0,
                 alpha_f_cooldown: float = 0.0) -> None:
        if alpha_f_decay < alpha_f_cooldown:
            raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. '
                              f'Current: alpha_f_decay={alpha_f_decay}, '
                              f'alpha_f_cooldown={alpha_f_cooldown}.'))
        _raise_if_units_dur(t_warmup, 't_warmup')
        _raise_if_units_dur(t_scale, 't_scale')
        _raise_if_units_dur(t_cooldown, 't_cooldown')

        # t_scale is the divisor of the normalized decay time, so a zero value
        # would only surface later as a ZeroDivisionError inside __call__, and
        # a negative one would take the square root of a negative number.
        parsed_scale = (Time.from_timestring(t_scale) if isinstance(
            t_scale, str) else t_scale)
        if parsed_scale.value <= 0:
            raise ValueError(
                f't_scale must be strictly positive. Current: t_scale={t_scale}.'
            )

        self.t_warmup = t_warmup
        self.t_scale = t_scale
        self.t_cooldown = t_cooldown
        self.t_max = t_max
        self.alpha_f_decay = alpha_f_decay
        self.alpha_f_cooldown = alpha_f_cooldown
        # Warmup is delegated to a linear ramp from 0 to 1 over t_warmup.
        self.warmup_scheduler = LinearScheduler(alpha_i=0.0,
                                                alpha_f=1.0,
                                                t_max=t_warmup)

    def _decay_multiplier(self, elapsed: Time, t_shift: Time,
                          t_scale: Time) -> float:
        """Return the inverse-square-root multiplier evaluated at ``elapsed``.

        The coefficient is ``1 / sqrt((elapsed + t_shift) / t_scale)``; it is
        blended with ``alpha_f_decay`` so the multiplier approaches
        ``alpha_f_decay`` as the coefficient goes to zero.
        """
        coeff = 1 / ((elapsed + t_shift) / t_scale).value**0.5
        return self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay)

    def __call__(self, state: State, ssr: float = 1.0) -> float:
        """Compute the learning rate multiplier for the current timestamp.

        Args:
            state (State): The training state; ``state.max_duration`` must be
                set and ``state.timestamp`` is the time being evaluated.
            ssr (float): Value forwarded to ``_convert_time`` when resolving
                ``t_scale``, ``t_cooldown`` and ``t_max``. It is not forwarded
                for ``t_warmup`` or to the warmup scheduler.

        Returns:
            float: The learning rate multiplier.

        Raises:
            ValueError: If ``t_warmup``, ``t_scale`` or ``t_cooldown`` is not
                in the same unit as ``state.max_duration``.
        """
        assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
        _raise_if_units_dont_match(self.t_warmup, state.max_duration,
                                   't_warmup')
        _raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale')
        _raise_if_units_dont_match(self.t_cooldown, state.max_duration,
                                   't_cooldown')

        t_warmup = _convert_time(self.t_warmup, state)
        if t_warmup.value == 0:
            warnings.warn(
                textwrap.dedent("""\
                    The warmup duration is 0. If warmup was specified as a fraction of the total
                    training duration, the warmup duration is calculated in the
                    same unit as the trainer's max_duration parameter."""))

        # Phase 1: linear warmup.
        if state.timestamp < t_warmup:
            return self.warmup_scheduler(state)

        t_scale = _convert_time(self.t_scale, state, ssr=ssr)
        t_cooldown = _convert_time(self.t_cooldown, state, ssr=ssr)
        t_max = _convert_time(self.t_max, state, ssr=ssr)
        current_time = state.timestamp.get(t_scale.unit)

        # Adding t_shift to a time t gives (t - t_warmup + t_scale), the
        # numerator of the normalized decay time.
        t_shift = t_scale - t_warmup
        # t_cooldown_start is max of t_warmup, t_max - t_cooldown
        t_cooldown_start = t_max - t_cooldown
        if t_cooldown_start < t_warmup:
            t_cooldown_start = t_warmup

        # Phase 2: inverse square root decay.
        if state.timestamp < t_cooldown_start:
            # Rescale LR by a coefficient equal to the inverse square root of the time
            # elapsed after warmup, rescaled by the time scale, such that, at
            # infinite time, the LR decays to alpha_f_decay.
            return self._decay_multiplier(current_time, t_shift, t_scale)

        # Phase 3: linear cooldown, starting from the decay multiplier at the
        # moment cooldown begins.
        alpha_i = self._decay_multiplier(t_cooldown_start, t_shift, t_scale)

        if t_cooldown.value == 0:
            # No cooldown configured: hold the multiplier reached by the decay.
            return alpha_i

        # Linearly decay the LR from its value at the step at which cooldown
        # started to alpha_f_cooldown over t_cooldown time.
        frac_of_cooldown = ((current_time - t_cooldown_start) /
                            t_cooldown).value
        frac_of_cooldown = min(1.0, frac_of_cooldown)
        current_factor = (alpha_i + frac_of_cooldown *
                          (self.alpha_f_cooldown - alpha_i))
        return current_factor


# NOTE: the accompanying change to llmfoundry/utils/builders.py imports this
# class and returns it from build_scheduler for name == 'inv_sqrt_with_warmup'.
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List

import pytest
import torch
from composer.core import State, Time, TimeUnit
from composer.devices import DeviceCPU, DeviceGPU
from composer.optim.scheduler import ComposerScheduler

from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler

_MAX_DURATION = '100ba'
_STEPS_PER_EPOCH = 100


@pytest.fixture
def dummy_schedulers_state(request: pytest.FixtureRequest):
    """Build a minimal ``State`` with ``max_duration`` set to ``_MAX_DURATION``.

    The device is chosen from the first collected test item only: GPU if that
    item carries the ``gpu`` marker, CPU otherwise.
    """
    device = None
    for item in request.session.items:
        device = DeviceCPU(
        ) if item.get_closest_marker('gpu') is None else DeviceGPU()
        break
    # Identity comparison is the correct way to test for None.
    assert device is not None
    state = State(
        model=torch.nn.Linear(5, 5),
        run_name='run_name',
        device=device,
        rank_zero_seed=17,
        max_duration=_MAX_DURATION,
    )
    state.set_dataloader([None] * _STEPS_PER_EPOCH, 'train')
    return state


@pytest.mark.parametrize('scheduler,ssr,test_times,expected_lrs', [
    pytest.param(
        InverseSquareRootWithWarmupScheduler(t_warmup='10ba',
                                             t_scale='10ba',
                                             t_cooldown='0ba',
                                             alpha_f_decay=0,
                                             alpha_f_cooldown=0), 1.0,
        ['0ba', '5ba', '10ba', '40ba', '90ba', '100ba'],
        [0.0, 0.5, 1.0, 0.5, 0.33333, 0.31623]),
    pytest.param(
        InverseSquareRootWithWarmupScheduler(t_warmup='20ba',
                                             t_scale='2ba',
                                             t_cooldown='10ba',
                                             alpha_f_decay=0.4,
                                             alpha_f_cooldown=0.1), 1.0,
        ['0ba', '10ba', '20ba', '36ba', '90ba', '95ba', '100ba'],
        [0.0, 0.5, 1.0, 0.6, 0.5, 0.3, 0.1]),
])
def test_scheduler_init(scheduler: ComposerScheduler, ssr: float,
                        test_times: List[str], expected_lrs: List[float],
                        dummy_schedulers_state: State):
    """Check the multiplier at selected timestamps against expected values."""
    state = dummy_schedulers_state
    assert state.dataloader_len is not None
    assert state.max_duration is not None
    state.max_duration = Time(value=int(state.max_duration.value * ssr),
                              unit=state.max_duration.unit)
    for test_time, expected_lr in zip(test_times, expected_lrs):
        parsed_time = Time.from_timestring(test_time)
        assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH]
        # Move the state's timestamp to the probed batch, keeping the epoch
        # counter consistent with the dataloader length.
        state.timestamp = state.timestamp.copy(
            batch=parsed_time,
            epoch=Time(
                int(parsed_time) // int(state.dataloader_len), TimeUnit.EPOCH),
        )
        lr = scheduler(state, ssr)
        assert lr == pytest.approx(expected_lr, abs=1e-3)


@pytest.mark.parametrize('state_unit,warmup_unit,scale_unit,cooldown_unit', [
    ['ep', 'ba', 'ba', 'ba'],
    ['ba', 'ep', 'ep', 'ep'],
    ['ep', 'ep', 'ba', 'ep'],
])
def test_scheduler_units_match_error(state_unit: str, warmup_unit: str,
                                     scale_unit: str, cooldown_unit: str,
                                     dummy_schedulers_state: State):
    """Calling the scheduler must fail when a unit differs from max_duration."""
    state = dummy_schedulers_state
    state.max_duration = f'1{state_unit}'
    scheduler = InverseSquareRootWithWarmupScheduler(
        t_warmup=f'10{warmup_unit}',
        t_scale=f'10{scale_unit}',
        t_cooldown=f'10{cooldown_unit}')
    with pytest.raises(ValueError, match='does not match'):
        _ = scheduler(state, 1.0)


@pytest.mark.parametrize('warmup_unit,scale_unit,cooldown_unit', [
    ['dur', 'ba', 'ba'],
    ['ba', 'dur', 'ba'],
    ['ba', 'ba', 'dur'],
])
def test_unit_dur_error(warmup_unit: str, scale_unit: str, cooldown_unit: str):
    """Construction must reject any of the three times given in ``dur``."""
    with pytest.raises(ValueError, match='cannot be in units of "dur".'):
        _ = InverseSquareRootWithWarmupScheduler(t_warmup=f'1{warmup_unit}',
                                                 t_scale=f'1{scale_unit}',
                                                 t_cooldown=f'1{cooldown_unit}')


def test_alpha_f_error():
    """Construction must reject ``alpha_f_decay < alpha_f_cooldown``."""
    with pytest.raises(ValueError, match='alpha_f_decay >= alpha_f_cooldown.'):
        _ = InverseSquareRootWithWarmupScheduler(t_warmup='10ba',
                                                 t_scale='10ba',
                                                 t_cooldown='10ba',
                                                 alpha_f_decay=0.0,
                                                 alpha_f_cooldown=0.1)