Commit: Merge branch 'main' into mlflow-log-model
Showing 19 changed files with 1,173 additions and 35 deletions.
@@ -0,0 +1,153 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Experimental learning rate schedulers used for training LLMs."""

import textwrap
import warnings
from typing import Union

from composer.core import State, Time, TimeUnit
from composer.optim import ComposerScheduler, LinearScheduler
from composer.optim.scheduler import _convert_time

__all__ = ['InverseSquareRootWithWarmupScheduler']

def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
                               name: str) -> None:
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if isinstance(t_max, str):
        t_max = Time.from_timestring(t_max)
    if time.unit != t_max.unit:
        raise ValueError(
            f'{name} (unit {time.unit}) does not match max_duration unit ({t_max.unit}).')

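# For example, _raise_if_units_dont_match('100ba', '1000ba', 't_warmup') passes,
# while _raise_if_units_dont_match('100ba', '10ep', 't_warmup') raises a
# ValueError because batches and epochs are different time units.
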
def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
    if isinstance(time, str):
        time = Time.from_timestring(time)
    if time.unit == TimeUnit('dur'):
        raise ValueError(f'{name} cannot be in units of "dur".')

class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
    r"""Inverse square root LR decay with warmup and optional linear cooldown.

    Specifically, the learning rate multiplier :math:`\alpha(t)` can be expressed as:

    .. math::
        \alpha(t) = \begin{cases}
            t / t_{warmup}, & \text{if } t < t_{warmup} \\
            \alpha_{f,decay} + \frac{1 - \alpha_{f,decay}}{\sqrt{\tau_d}}, & \text{if } t_{warmup} \leq t < t_{max} - t_{cooldown} \\
            \alpha_i + (\alpha_{f,cooldown} - \alpha_i) \times \tau_c, & \text{otherwise}
        \end{cases}

    Given :math:`\tau_d`, the time elapsed during the inverse square root decay (normalized by :math:`t_{scale}`), as:

    .. math::
        \tau_d = (t - t_{warmup} + t_{scale}) / t_{scale}

    :math:`\alpha_i` as the value of the learning rate multiplier when :math:`\tau_d` is evaluated at :math:`t = t_{max} - t_{cooldown}`,
    and :math:`\tau_c`, the fraction of linear cooldown time elapsed (clipped to the interval :math:`[0, 1]`), as:

    .. math::
        \tau_c = (t - t_{max} + t_{cooldown}) / t_{cooldown}

    Where :math:`t_{warmup}` represents the warmup time, :math:`t_{scale}` represents the time scale,
    :math:`t_{cooldown}` represents the cooldown time, :math:`t_{max}` represents the duration of this scheduler,
    :math:`\alpha_{f,decay}` represents the learning rate multiplier that the inverse square root decay approaches at infinite time,
    and :math:`\alpha_{f,cooldown}` represents the learning rate multiplier that the linear cooldown decays to.

    Note that :math:`\alpha_{f,decay} \geq \alpha_{f,cooldown}` is required so that the learning rate is monotonically decreasing after warmup.

    Also note that ``t_warmup``, ``t_scale``, and ``t_cooldown`` cannot be specified in units of duration: because this schedule is designed for continual learning,
    ``max_duration`` is expected to change. Instead, these parameters must be specified in the same units as the ``max_duration`` passed to the trainer.

    Args:
        t_warmup (str | Time): The warmup time.
        t_scale (str | Time): The time scale.
        t_cooldown (str | Time): The cooldown time.
        t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
        alpha_f_decay (float): The learning rate multiplier that the inverse square root decay approaches. Default = ``0.0``.
        alpha_f_cooldown (float): The learning rate multiplier that the linear cooldown decays to. Default = ``0.0``.
    """

    def __init__(self,
                 t_warmup: Union[str, Time],
                 t_scale: Union[str, Time],
                 t_cooldown: Union[str, Time],
                 t_max: Union[str, Time] = '1dur',
                 alpha_f_decay: float = 0.0,
                 alpha_f_cooldown: float = 0.0) -> None:
        if alpha_f_decay < alpha_f_cooldown:
            raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. '
                              f'Current: alpha_f_decay={alpha_f_decay}, '
                              f'alpha_f_cooldown={alpha_f_cooldown}.'))
        _raise_if_units_dur(t_warmup, 't_warmup')
        _raise_if_units_dur(t_scale, 't_scale')
        _raise_if_units_dur(t_cooldown, 't_cooldown')
        self.t_warmup = t_warmup
        self.t_scale = t_scale
        self.t_cooldown = t_cooldown
        self.t_max = t_max
        self.alpha_f_decay = alpha_f_decay
        self.alpha_f_cooldown = alpha_f_cooldown
        self.warmup_scheduler = LinearScheduler(alpha_i=0.0,
                                                alpha_f=1.0,
                                                t_max=t_warmup)

    def __call__(self, state: State, ssr: float = 1.0) -> float:
        assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
        _raise_if_units_dont_match(self.t_warmup, state.max_duration,
                                   't_warmup')
        _raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale')
        _raise_if_units_dont_match(self.t_cooldown, state.max_duration,
                                   't_cooldown')

        t_warmup = _convert_time(self.t_warmup, state)
        if t_warmup.value == 0:
            warnings.warn(
                textwrap.dedent("""\
                    The warmup duration is 0. If warmup was specified as a fraction of the total
                    training duration, the warmup duration is calculated in the
                    same unit as the trainer's max_duration parameter."""))

        if state.timestamp < t_warmup:
            return self.warmup_scheduler(state)

        t_scale = _convert_time(self.t_scale, state, ssr=ssr)
        t_cooldown = _convert_time(self.t_cooldown, state, ssr=ssr)
        t_max = _convert_time(self.t_max, state, ssr=ssr)
        current_time = state.timestamp.get(t_scale.unit)

        t_shift = t_scale - t_warmup
        # t_cooldown_start is max(t_warmup, t_max - t_cooldown)
        t_cooldown_start = t_max - t_cooldown
        if t_cooldown_start < t_warmup:
            t_cooldown_start = t_warmup

        if state.timestamp < t_cooldown_start:
            # Rescale the LR by a coefficient equal to the inverse square root of the
            # time elapsed after warmup, rescaled by the time scale, such that, at
            # infinite time, the LR decays to alpha_f_decay.
            coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5
            current_factor = (self.alpha_f_decay + coeff *
                              (1.0 - self.alpha_f_decay))
            return current_factor

        else:
            coeff = 1 / ((t_cooldown_start + t_shift) / t_scale).value**0.5
            alpha_i = self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay)

            if t_cooldown.value == 0:
                return alpha_i

            # Linearly decay the LR from its value at the step at which cooldown
            # started to alpha_f_cooldown over t_cooldown time.
            frac_of_cooldown = ((current_time - t_cooldown_start) /
                                t_cooldown).value
            frac_of_cooldown = min(1.0, frac_of_cooldown)
            current_factor = (alpha_i + frac_of_cooldown *
                              (self.alpha_f_cooldown - alpha_i))
            return current_factor
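
Below is a minimal usage sketch for the new scheduler with Composer's Trainer. The import path, the model, dataloader, and optimizer objects, and all hyperparameter values are illustrative assumptions, not taken from this diff:

from composer import Trainer

# Hypothetical import path; adjust to wherever this file lives in the repository.
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler

# Warm up for 100 batches, decay as an inverse square root on a 500-batch time
# scale, and finish with a 200-batch linear cooldown. All times are in batches
# ('ba'), the same unit as max_duration below, as the class requires.
scheduler = InverseSquareRootWithWarmupScheduler(
    t_warmup='100ba',
    t_scale='500ba',
    t_cooldown='200ba',
    alpha_f_decay=0.1,
    alpha_f_cooldown=0.0,
)

trainer = Trainer(
    model=model,                        # a ComposerModel built elsewhere
    train_dataloader=train_dataloader,  # built elsewhere
    optimizers=optimizer,               # built elsewhere
    schedulers=[scheduler],
    max_duration='10000ba',             # same units as the scheduler times
)
trainer.fit()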