diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 468bd6e7f8..30ba53ad25 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -153,7 +153,7 @@ def normalize_config(cfg): cfg.is_llama_derived_model = ( ( hasattr(model_config, "model_type") - and model_config.model_type == ["llama", "mllama_text_model"] + and model_config.model_type in ["llama", "mllama_text_model"] ) or cfg.is_llama_derived_model or "llama" in cfg.base_model.lower()