Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support PEFT models in pipelines #29517

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,12 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
if "Peft" in self.model.__class__.__name__ and hasattr(self.model, "base_model"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, checking if Peft is in the name seems like a flaky way to check this.

What I'd suggest is adding a more general is_peft_model utility under modeling_utils.py which is tested (and we check with @younesbelkada is correct) which can then be used everywhere. It should also check is_peft_available.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes you can do if the model is an isntance of Peftxxx class, however we do support already peft model inference: https://huggingface.co/docs/transformers/peft , users that pass a Peftxxx class can be considered as bad intent, as they only need to pass a valid path to an adapter model or first load a peft model using AutoModelForxxx and pass that to the pipeline. See for example:

def test_peft_pipeline(self):

# Peft models wrap a base model class, so let's look at the base class instead in that case
class_name = self.model.base_model.model.__class__.__name__
else:
class_name = self.model.__class__.__name__
if class_name not in supported_models:
logger.error(
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are"
f" {supported_models}."
Expand Down
Loading