diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 4b9c79d848..9f0795af76 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -371,7 +371,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
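
With `or`, the norm modules are cast back to `cfg.torch_dtype` whenever `needs_fa2_dtype` is set, rather than only when flash attention is enabled on a llama-derived model. A minimal sketch of how the revised guard evaluates, using stand-in values instead of axolotl's real config object (the `should_convert_norms` helper and the `SimpleNamespace` cfg are hypothetical, for illustration only):

```python
from types import SimpleNamespace

def should_convert_norms(needs_fa2_dtype: bool, cfg) -> bool:
    # Revised guard: convert norm layers whenever needs_fa2_dtype is set,
    # or when flash attention is enabled on a llama-derived model.
    return needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model)

cfg = SimpleNamespace(flash_attention=False, is_llama_derived_model=False)
print(should_convert_norms(True, cfg))   # True  (the old "and" guard gave False here)
print(should_convert_norms(False, cfg))  # False
```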