diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 758484107b7..646ac37d341 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1038,7 +1038,12 @@ def check_model_type(self, supported_models: Union[List[str], dict]): else: supported_models_names.append(model.__name__) supported_models = supported_models_names - if self.model.__class__.__name__ not in supported_models: + if "Peft" in self.model.__class__.__name__ and hasattr(self.model, "base_model"): + # Peft models wrap a base model class, so let's look at the base class instead in that case + class_name = self.model.base_model.model.__class__.__name__ + else: + class_name = self.model.__class__.__name__ + if class_name not in supported_models: logger.error( f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are" f" {supported_models}."