diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index dfa648d499..6e520bd50e 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -200,7 +200,11 @@ def load_model( ) # sample packing uses custom FA2 patch if cfg.flash_attention and not cfg.sample_packing: - if cfg.is_llama_derived_model or cfg.is_falcon_derived_model or cfg.is_mistral_derived_model: + if ( + cfg.is_llama_derived_model + or cfg.is_falcon_derived_model + or cfg.is_mistral_derived_model + ): model_kwargs["use_flash_attention_2"] = True try: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: