Skip to content

Commit

Permalink
feat: add option to choose score_retrieval's output dtype and device
Browse files Browse the repository at this point in the history
  • Loading branch information
tonywu71 committed Nov 2, 2024
1 parent 5daf5e6 commit 11f2621
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
21 changes: 15 additions & 6 deletions src/transformers/models/colpali/modular_colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def score_retrieval(
query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
batch_size: int = 128,
output_dtype: Optional[torch.dtype] = torch.float32,
output_device: Union[torch.device, str] = "cpu",
) -> torch.Tensor:
"""
Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
Expand All @@ -377,10 +379,13 @@ def score_retrieval(
query_embeddings (`List[torch.Tensor]`): List of query embeddings.
passage_embeddings (`List[torch.Tensor]`): List of passage embeddings.
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
If `None`, the dtype of the input embeddings is used.
output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
Returns:
`torch.Tensor`: A tensor of shape `(len(qs), len(ps))` containing the scores
(device=cpu, dtype=float32).
`torch.Tensor`: A tensor of shape `(len(query_embeddings), len(passage_embeddings))` containing the scores. The score
tensor has dtype `output_dtype` and is placed on `output_device`.
"""

if len(query_embeddings) == 0:
Expand All @@ -391,6 +396,12 @@ def score_retrieval(
if query_embeddings[0].device != passage_embeddings[0].device:
raise ValueError("Queries and passages must be on the same device")

if query_embeddings[0].dtype != passage_embeddings[0].dtype:
raise ValueError("Queries and passages must have the same dtype")

if output_dtype is None:
output_dtype = query_embeddings[0].dtype

scores: List[torch.Tensor] = []

for i in range(0, len(query_embeddings), batch_size):
Expand All @@ -405,11 +416,9 @@ def score_retrieval(
batch_scores.append(
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
)
scores.append(torch.cat(batch_scores, dim=1).cpu())

scores = torch.cat(scores, dim=0).to(torch.float32)
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))

return scores
return torch.cat(scores, dim=0)

def get_n_patches(
self,
Expand Down
23 changes: 16 additions & 7 deletions src/transformers/models/colpali/processing_colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# limitations under the License.


from typing import ClassVar, List, Tuple, Union
from typing import ClassVar, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -372,6 +372,8 @@ def score_retrieval(
query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
batch_size: int = 128,
output_dtype: Optional[torch.dtype] = torch.float32,
output_device: Union[torch.device, str] = "cpu",
) -> torch.Tensor:
"""
Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
Expand All @@ -382,10 +384,13 @@ def score_retrieval(
query_embeddings (`List[torch.Tensor]`): List of query embeddings.
passage_embeddings (`List[torch.Tensor]`): List of passage embeddings.
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
If `None`, the dtype of the input embeddings is used.
output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
Returns:
`torch.Tensor`: A tensor of shape `(len(qs), len(ps))` containing the scores
(device=cpu, dtype=float32).
`torch.Tensor`: A tensor of shape `(len(query_embeddings), len(passage_embeddings))` containing the scores. The score
tensor has dtype `output_dtype` and is placed on `output_device`.
"""

if len(query_embeddings) == 0:
Expand All @@ -396,6 +401,12 @@ def score_retrieval(
if query_embeddings[0].device != passage_embeddings[0].device:
raise ValueError("Queries and passages must be on the same device")

if query_embeddings[0].dtype != passage_embeddings[0].dtype:
raise ValueError("Queries and passages must have the same dtype")

if output_dtype is None:
output_dtype = query_embeddings[0].dtype

scores: List[torch.Tensor] = []

for i in range(0, len(query_embeddings), batch_size):
Expand All @@ -410,11 +421,9 @@ def score_retrieval(
batch_scores.append(
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
)
scores.append(torch.cat(batch_scores, dim=1).cpu())

scores = torch.cat(scores, dim=0).to(torch.float32)
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))

return scores
return torch.cat(scores, dim=0)

def get_n_patches(
self,
Expand Down

0 comments on commit 11f2621

Please sign in to comment.