diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 1c0487ff8e..cb18380cb7 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -94,7 +94,7 @@ def validate_config(cfg): if not cfg.bf16 and not cfg.bfloat16: LOG.info("bf16 support detected, but not enabled for this configuration.") else: - if cfg.bf16 or cfg.bfloat16: + if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16): raise ValueError( "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." ) diff --git a/tests/test_validation.py b/tests/test_validation.py index f250e5cb47..d7935c1a54 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -351,3 +351,26 @@ def test_packing(self): regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) + + def test_merge_lora_no_bf16_fail(self): + """ + This is assumed to be run on a CPU machine, so bf16 is not supported. + """ + + cfg = DictDefault( + { + "bf16": True, + } + ) + + with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"): + validate_config(cfg) + + cfg = DictDefault( + { + "bf16": True, + "merge_lora": True, + } + ) + + validate_config(cfg)