From 8430db22e29d613bc4151af830e7916c884b61ec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=A7=84=EC=9B=90?=
Date: Tue, 13 Feb 2024 14:23:28 +0900
Subject: [PATCH] Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273)

---
 README.md                           |  1 +
 src/axolotl/core/trainer_builder.py | 20 +++++++
 src/axolotl/utils/schedulers.py     | 81 ++++++++++++++++++++++++++++-
 tests/test_schedulers.py            | 52 ++++++++++++++++++
 4 files changed, 152 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_schedulers.py

diff --git a/README.md b/README.md
index cf38650649..5cb1df3246 100644
--- a/README.md
+++ b/README.md
@@ -813,6 +813,7 @@ early_stopping_patience: 3
 lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
 lr_scheduler_kwargs:
 cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
+cosine_constant_lr_ratio: # freeze lr at some percentage of the total training steps, e.g. cosine_constant_lr_ratio=0.8 holds the lr at cosine_min_lr from 80% of the training steps onward (https://arxiv.org/pdf/2308.04014.pdf)
 
 # For one_cycle optim
 lr_div_factor: # Learning rate div factor
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index fd7aeef535..e2667aea43 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -50,6 +50,7 @@
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
+    get_cosine_schedule_with_warmup_decay_constant,
 )
 
 try:
@@ -164,6 +165,12 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
     )
+    cosine_constant_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The constant learning rate phase starts at step cosine_constant_lr_ratio * max_steps"
+        },
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -221,6 +228,16 @@ def create_scheduler(
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                     num_training_steps=num_training_steps,
                 )
+            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
+                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
+                assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
+                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
+                    optimizer,
+                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                    min_lr_ratio=self.args.cosine_min_lr_ratio,
+                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,
+                )
             elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                 assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                 self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
@@ -887,6 +904,9 @@ def build(self, total_num_steps):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
+        training_arguments_kwargs[
+            "cosine_constant_lr_ratio"
+        ] = self.cfg.cosine_constant_lr_ratio
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index c49745c263..94387e5ab8 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -52,7 +52,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    num_cycles: float
+    num_cycles: float,
 ):
     if current_step < num_warmup_steps:
         return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +107,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    min_lr_ratio: float
+    min_lr_ratio: float,
 ):
     # Warm up
     if current_step < num_warmup_steps:
@@ -140,3 +140,80 @@ def get_cosine_schedule_with_min_lr(
         min_lr_ratio=min_lr_ratio,
     )
     return LambdaLR(optimizer, lr_lambda)
+
+
+def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    constant_lr_ratio: float,
+    min_lr_ratio: float,
+    num_cycles: float,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    num_constant_steps = int(num_training_steps * constant_lr_ratio)
+    current_step = min(current_step, num_constant_steps)
+
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_constant_steps - num_warmup_steps)
+    )
+
+    return (
+        max(
+            0,
+            (1 - min_lr_ratio)
+            * 0.5
+            * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+        )
+        + min_lr_ratio
+    )
+
+
+def get_cosine_schedule_with_warmup_decay_constant(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    constant_lr_ratio: float,
+    min_lr_ratio: float,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Implementation of the schedule from Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
+    Creates a schedule where, after a warmup period during which the learning rate increases linearly from 0 to the
+    initial lr set in the optimizer, the learning rate decreases following a cosine function from the initial lr down
+    to min_lr_ratio * initial lr at step num_training_steps * constant_lr_ratio, and then stays constant at that value.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        constant_lr_ratio (`float`):
+            The fraction of num_training_steps at which the cosine decay ends and the constant phase begins.
+        min_lr_ratio (`float`):
+            The ratio of the minimum (constant) learning rate to the maximum learning rate.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to the
+            minimum following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_decay_constant_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + constant_lr_ratio=constant_lr_ratio, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py new file mode 100644 index 0000000000..9402d7af7f --- /dev/null +++ b/tests/test_schedulers.py @@ -0,0 +1,52 @@ +""" +test module for the axolotl.utis.data module +""" +import unittest + +import torch +from torch.optim import SGD + +from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant + + +class TestCosineConstantLr(unittest.TestCase): + """ + test class for encode pretraining and md5 helper + """ + + def setUp(self): + self.train_steps = 1000 + self.warmup_steps = 10 + self.min_lr_ratio = 0.1 + self.constant_lr_ratio = 0.8 + self._lr = 0.01 + self.optimizer = SGD([torch.tensor(1)], lr=self._lr) + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + self.optimizer, + num_warmup_steps=self.warmup_steps, + num_training_steps=self.train_steps, + min_lr_ratio=self.min_lr_ratio, + constant_lr_ratio=self.constant_lr_ratio, + ) + + def test_schedulers(self): + self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0) + for _ in range(self.warmup_steps): + self.lr_scheduler.step() + self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr) + constant_step = int(self.train_steps * self.constant_lr_ratio) + remaining_step = self.train_steps - constant_step + for _ in range(constant_step): + self.lr_scheduler.step() + self.assertEqual( + self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio + ) + for _ in range(remaining_step): + self.lr_scheduler.step() + self.assertEqual( + self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio + ) + + +if __name__ == "__main__": + unittest.main()