diff --git a/optimum/tpu/modeling.py b/optimum/tpu/modeling.py index e651e260..2af1ac21 100644 --- a/optimum/tpu/modeling.py +++ b/optimum/tpu/modeling.py @@ -24,7 +24,8 @@ from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM def config_name_to_class(pretrained_model_name_or_path: str): - if "gemma" in pretrained_model_name_or_path: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + if config.model_type == "gemma": return TpuGemmaForCausalLM return BaseAutoModelForCausalLM