diff --git a/README.md b/README.md index fc300d6051..717c99bf0c 100644 --- a/README.md +++ b/README.md @@ -675,7 +675,8 @@ gradient_accumulation_steps: 1 micro_batch_size: 2 eval_batch_size: num_epochs: 4 -warmup_steps: 100 +warmup_steps: 100 # cannot use with warmup_ratio +warmup_ratio: 0.05 # cannot use with warmup_steps learning_rate: 0.00003 lr_quadratic_warmup: logging_steps: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6b78f1f1ab..62e527beb0 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -461,11 +461,14 @@ def _get_trainer_cls(self): return AxolotlTrainer def build(self, total_num_steps): - warmup_steps = ( - self.cfg.warmup_steps - if self.cfg.warmup_steps is not None - else min(int(0.03 * total_num_steps), 100) - ) + warmup_steps = None + if self.cfg.warmup_steps is not None: + warmup_steps = self.cfg.warmup_steps + elif self.cfg.warmup_ratio is not None: + warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) + else: + warmup_steps = min(int(0.03 * total_num_steps), 100) + logging_steps = ( self.cfg.logging_steps if self.cfg.logging_steps is not None diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index d2db92a633..c41e059cde 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -372,6 +372,9 @@ def validate_config(cfg): if cfg.rope_scaling: LOG.warning("`rope_scaling` should now be be a key under `model_config`") + if cfg.warmup_steps and cfg.warmup_ratio: + raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index dbd030da33..5a4ef427ba 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -649,3 +649,33 @@ def test_load_in_x_bit_without_adapter(self): ) validate_config(cfg) + + def test_warmup_step_no_conflict(self): + cfg = DictDefault( + { + "warmup_steps": 10, + "warmup_ratio": 0.1, + } + ) + + with pytest.raises( + ValueError, + match=r".*warmup_steps and warmup_ratio are mutually exclusive*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "warmup_steps": 10, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "warmup_ratio": 0.1, + } + ) + + validate_config(cfg)