diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py
index f21c96d3fc..af7e494122 100644
--- a/llmfoundry/optim/scheduler.py
+++ b/llmfoundry/optim/scheduler.py
@@ -5,19 +5,70 @@
 import warnings
 from typing import Union
 
-from composer.core import State, Time, TimeUnit
+from composer.core import State, Time
 from composer.optim import ComposerScheduler, LinearScheduler
 from composer.optim.scheduler import _convert_time
 
+__all__ = ['InverseSquareRootWithWarmupScheduler']
 
-def _raise_if_unit_not_ba(time: Union[str, Time]) -> None:
+
+def _raise_if_units_dont_match(time: Union[str, Time],
+                               t_max: Union[str, Time]) -> None:
     if isinstance(time, str):
         time = Time.from_timestring(time)
-    if time.unit != TimeUnit('ba'):
-        raise ValueError
+    if isinstance(t_max, str):
+        t_max = Time.from_timestring(t_max)
+    if time.unit != t_max.unit:
+        raise ValueError(
+            'All time units must be the same as max_duration units.')
 
 
 class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
+    r"""Inverse square root LR decay with warmup and optional linear cooldown.
+
+    Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as:
+
+    .. math::
+        \alpha(t) = \begin{cases}
+            t / t_{warmup}, & \text{if } t < t_{warmup} \\
+            \alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} \leq t < t_{max} - t_{cooldown} \\
+            \alpha_i + (\alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise}
+        \end{cases}
+
+    Given :math:`\tau_d`, the time elapsed during the inverse square root decay (normalized by :math:`t_{scale}`), as:
+
+    .. math::
+        \tau_d = (t - t_{warmup} + t_{scale}) / t_{scale}
+
+    :math:`\alpha_i` is the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`,
+    and :math:`\tau_c`, the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`), is given by:
+
+    .. math::
+        \tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown}
+
+    Here, :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale,
+    :math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler,
+    :math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decay approaches at infinite time,
+    and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to.
+
+    Note that :math:`\alpha_{f,decay} \geq \alpha_{f,cooldown}` is required so that the learning rate decreases monotonically after warmup.
+
+    Also note that ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration; because this schedule
+    is designed for continual learning, ``max_duration`` is expected to change. Instead, specify these parameters in the same units as the ``max_duration`` passed to the trainer.
+
+    .. warning::
+        By default, the initial warmup time is **not** scaled according to any provided scale schedule ratio.
+        To change this behavior, set ``scale_warmup=True``.
+
+    Args:
+        t_warmup (str | Time): Warmup time.
+        t_scale (str | Time): Time scale.
+        t_cooldown (str | Time): Cooldown time.
+        t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
+        alpha_f_decay (float): Learning rate multiplier that the inverse square root decay approaches at infinite time. Default = ``0.0``.
+        alpha_f_cooldown (float): Learning rate multiplier that the linear cooldown decays to. Default = ``0.0``.
+        scale_warmup (bool): Whether the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
+    """
 
     def __init__(self,
                  t_warmup: Union[str, Time],
@@ -27,11 +78,8 @@ def __init__(self,
                  alpha_f_decay: float = 0.0,
                  alpha_f_cooldown: float = 0.0,
                  scale_warmup: bool = False):
-        _raise_if_unit_not_ba(t_warmup)
-        _raise_if_unit_not_ba(t_scale)
-        _raise_if_unit_not_ba(t_cooldown)
-        if alpha_f_cooldown > alpha_f_decay:
-            raise ValueError
+        if alpha_f_decay < alpha_f_cooldown:
+            raise ValueError('Required: alpha_f_decay >= alpha_f_cooldown.')
         self.t_warmup = t_warmup
         self.t_scale = t_scale
         self.t_cooldown = t_cooldown
@@ -45,7 +93,9 @@ def __init__(self,
 
     def __call__(self, state: State, ssr: float = 1.0) -> float:
         assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
-        _raise_if_unit_not_ba(state.max_duration)
+        _raise_if_units_dont_match(self.t_warmup, state.max_duration)
+        _raise_if_units_dont_match(self.t_scale, state.max_duration)
+        _raise_if_units_dont_match(self.t_cooldown, state.max_duration)
 
         t_warmup = _convert_time(self.t_warmup, state)
         if t_warmup.value == 0: