diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 2f24c1199a131e..0323236ef5ac80 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -267,11 +267,11 @@ def _init_weights(self, module): import torch from PIL import Image - from transformers import ColPali, ColPaliProcessor + from transformers import ColPaliForRetrieval, ColPaliProcessor model_name = "vidore/colpali-v1.2-hf" - model = ColPali.from_pretrained( + model = ColPaliForRetrieval.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cuda:0", # or "mps" if on Apple Silicon diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py index 0f43849457a843..9276d7525a652b 100644 --- a/src/transformers/models/colpali/modular_colpali.py +++ b/src/transformers/models/colpali/modular_colpali.py @@ -548,11 +548,11 @@ class ColPaliForRetrievalOutput(ModelOutput): import torch from PIL import Image - from transformers import ColPali, ColPaliProcessor + from transformers import ColPaliForRetrieval, ColPaliProcessor model_name = "vidore/colpali-v1.2-hf" - model = ColPali.from_pretrained( + model = ColPaliForRetrieval.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cuda:0", # or "mps" if on Apple Silicon