Skip to content

Commit

Permalink
Feat: Add warmup_ratio (#893)
Browse files Browse the repository at this point in the history
* Feat: Add warmup_ratio

* fix: update readme with more details on conflict
  • Loading branch information
NanoCode012 authored Nov 25, 2023
1 parent 9fc29e0 commit fb12895
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size:
num_epochs: 4
warmup_steps: 100
warmup_steps: 100 # cannot use with warmup_ratio
warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
Expand Down
13 changes: 8 additions & 5 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,14 @@ def _get_trainer_cls(self):
return AxolotlTrainer

def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
warmup_steps = None
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)

logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ def validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")

if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
Expand Down
30 changes: 30 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,33 @@ def test_load_in_x_bit_without_adapter(self):
)

validate_config(cfg)

def test_warmup_step_no_conflict(self):
cfg = DictDefault(
{
"warmup_steps": 10,
"warmup_ratio": 0.1,
}
)

with pytest.raises(
ValueError,
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
):
validate_config(cfg)

cfg = DictDefault(
{
"warmup_steps": 10,
}
)

validate_config(cfg)

cfg = DictDefault(
{
"warmup_ratio": 0.1,
}
)

validate_config(cfg)

0 comments on commit fb12895

Please sign in to comment.