diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py index 9276d7525a652b..985d68834f70b9 100644 --- a/src/transformers/models/colpali/modular_colpali.py +++ b/src/transformers/models/colpali/modular_colpali.py @@ -367,6 +367,8 @@ def score_retrieval( query_embeddings: Union[torch.Tensor, List[torch.Tensor]], passage_embeddings: Union[torch.Tensor, List[torch.Tensor]], batch_size: int = 128, + output_dtype: Optional[torch.dtype] = torch.float32, + output_device: Union[torch.device, str] = "cpu", ) -> torch.Tensor: """ Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector @@ -377,10 +379,13 @@ def score_retrieval( query_embeddings (`List[torch.Tensor]`): List of query embeddings. passage_embeddings (`List[torch.Tensor]`): List of passage embeddings. batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. + output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor. + If `None`, the dtype of the input embeddings is used. + output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor. Returns: - `torch.Tensor`: A tensor of shape `(len(qs), len(ps))` containing the scores - (device=cpu, dtype=float32). + `torch.Tensor`: A tensor of shape `(len(qs), len(ps))` containing the scores. The score + tensor is saved on the "cpu" device. """ if len(query_embeddings) == 0: @@ -391,6 +396,12 @@ def score_retrieval( if query_embeddings[0].device != passage_embeddings[0].device: raise ValueError("Queries and passages must be on the same device") + if query_embeddings[0].dtype != passage_embeddings[0].dtype: + raise ValueError("Queries and passages must have the same dtype") + + if output_dtype is None: + output_dtype = query_embeddings[0].dtype + scores: List[torch.Tensor] = [] for i in range(0, len(query_embeddings), batch_size): @@ -405,11 +416,9 @@ def score_retrieval( batch_scores.append( torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2) ) - scores.append(torch.cat(batch_scores, dim=1).cpu()) - - scores = torch.cat(scores, dim=0).to(torch.float32) + scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device)) - return scores + return torch.cat(scores, dim=0) def get_n_patches( self, diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py index f77e6932f0e54c..7432a147dcfbca 100644 --- a/src/transformers/models/colpali/processing_colpali.py +++ b/src/transformers/models/colpali/processing_colpali.py @@ -20,7 +20,7 @@ # limitations under the License. -from typing import ClassVar, List, Tuple, Union +from typing import ClassVar, List, Optional, Tuple, Union import torch @@ -372,6 +372,8 @@ def score_retrieval( query_embeddings: Union[torch.Tensor, List[torch.Tensor]], passage_embeddings: Union[torch.Tensor, List[torch.Tensor]], batch_size: int = 128, + output_dtype: Optional[torch.dtype] = torch.float32, + output_device: Union[torch.device, str] = "cpu", ) -> torch.Tensor: """ Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector @@ -382,10 +384,13 @@ def score_retrieval( query_embeddings (`List[torch.Tensor]`): List of query embeddings. passage_embeddings (`List[torch.Tensor]`): List of passage embeddings. batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. + output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor. + If `None`, the dtype of the input embeddings is used. + output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor. Returns: - `torch.Tensor`: A tensor of shape `(len(qs), len(ps))` containing the scores - (device=cpu, dtype=float32). + `torch.Tensor`: A tensor of shape `(len(qs), len(ps))` containing the scores. The score + tensor is saved on the "cpu" device. """ if len(query_embeddings) == 0: @@ -396,6 +401,12 @@ def score_retrieval( if query_embeddings[0].device != passage_embeddings[0].device: raise ValueError("Queries and passages must be on the same device") + if query_embeddings[0].dtype != passage_embeddings[0].dtype: + raise ValueError("Queries and passages must have the same dtype") + + if output_dtype is None: + output_dtype = query_embeddings[0].dtype + scores: List[torch.Tensor] = [] for i in range(0, len(query_embeddings), batch_size): @@ -410,11 +421,9 @@ def score_retrieval( batch_scores.append( torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2) ) - scores.append(torch.cat(batch_scores, dim=1).cpu()) - - scores = torch.cat(scores, dim=0).to(torch.float32) + scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device)) - return scores + return torch.cat(scores, dim=0) def get_n_patches( self,