Skip to content

Commit

Permalink
Fix: Fail bf16 check when running on cpu during merge
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Sep 24, 2023
1 parent 67b9888 commit bd25f73
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def validate_config(cfg):
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if cfg.bf16 or cfg.bfloat16:
if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,26 @@ def test_packing(self):
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)

def test_merge_lora_no_bf16_fail(self):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.
"""

cfg = DictDefault(
{
"bf16": True,
}
)

with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
validate_config(cfg)

cfg = DictDefault(
{
"bf16": True,
"merge_lora": True,
}
)

validate_config(cfg)

0 comments on commit bd25f73

Please sign in to comment.