Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] query_namespaces can handle single result #421

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import json
from typing import Union, List, Optional, Dict, Any
from typing import Union, List, Optional, Dict, Any, Literal

from pinecone.config import ConfigBuilder

Expand Down Expand Up @@ -511,6 +511,7 @@ def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
Expand Down Expand Up @@ -540,6 +541,7 @@ def query_namespaces(
combined_results = index.query_namespaces(
vector=query_vec,
namespaces=['ns1', 'ns2', 'ns3', 'ns4'],
metric="cosine",
top_k=10,
filter={'genre': {"$eq": "drama"}},
include_values=True,
Expand All @@ -554,6 +556,7 @@ def query_namespaces(
vector (List[float]): The query vector, must be the same length as the dimension of the index being queried.
namespaces (List[str]): The list of namespaces to query.
top_k (Optional[int], optional): The number of results you would like to request from each namespace. Defaults to 10.
metric (str): Must be one of 'cosine', 'euclidean', 'dotproduct'. This is needed in order to merge results across namespaces, since the interpretation of score depends on the index metric type.
filter (Optional[Dict[str, Union[str, float, int, bool, List, dict]]], optional): Pass an optional filter to filter results based on metadata. Defaults to None.
include_values (Optional[bool], optional): Boolean field indicating whether vector values should be included with results. Defaults to None.
include_metadata (Optional[bool], optional): Boolean field indicating whether vector metadata should be included with results. Defaults to None.
Expand All @@ -568,7 +571,7 @@ def query_namespaces(
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)
aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric)

target_namespaces = set(namespaces) # dedup namespaces
async_futures = [
Expand Down
62 changes: 23 additions & 39 deletions pinecone/data/query_results_aggregator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional, Any, Dict
from typing import List, Tuple, Optional, Any, Dict, Literal
import json
import heapq
from pinecone.core.openapi.data.models import Usage
Expand Down Expand Up @@ -88,46 +88,38 @@ def __repr__(self):
)


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
def __init__(self):
super().__init__(
"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
)


class QueryResultsAggregatorInvalidTopKError(Exception):
def __init__(self, top_k: int):
super().__init__(
f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
)
super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.")


class QueryResultsAggregator:
def __init__(self, top_k: int):
if top_k < 2:
def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]):
if top_k < 1:
raise QueryResultsAggregatorInvalidTopKError(top_k)

if metric in ["dotproduct", "cosine"]:
self.is_bigger_better = True
elif metric in ["euclidean"]:
self.is_bigger_better = False
else:
raise ValueError(
f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'."
)

self.top_k = top_k
self.usage_read_units = 0
self.heap: List[Tuple[float, int, object, str]] = []
self.insertion_counter = 0
self.is_dotproduct = None
self.read = False
self.final_results: Optional[QueryNamespacesResults] = None

def _is_dotproduct_index(self, matches):
# The interpretation of the score depends on the similar metric used.
# Unlike other index types, in indexes configured for dotproduct,
# a higher score is better. We have to infer this is the case by inspecting
# the order of the scores in the results.
for i in range(1, len(matches)):
if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase
return False
return True

def _dotproduct_heap_item(self, match, ns):
def _bigger_better_heap_item(self, match, ns):
# This 4-tuple is used to ensure that the heap is sorted by score followed by
# insertion order. The insertion order is used to break any ties in the score.
return (match.get("score"), -self.insertion_counter, match, ns)

def _non_dotproduct_heap_item(self, match, ns):
def _smaller_better_heap_item(self, match, ns):
return (-match.get("score"), -self.insertion_counter, match, ns)

def _process_matches(self, matches, ns, heap_item_fn):
Expand All @@ -137,10 +129,10 @@ def _process_matches(self, matches, ns, heap_item_fn):
heapq.heappush(self.heap, heap_item_fn(match, ns))
else:
# Assume we have dotproduct scores sorted in descending order
if self.is_dotproduct and match["score"] < self.heap[0][0]:
if self.is_bigger_better and match["score"] < self.heap[0][0]:
# No further matches can improve the top-K heap
break
elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
elif not self.is_bigger_better and match["score"] > -self.heap[0][0]:
# No further matches can improve the top-K heap
break
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
Expand All @@ -162,18 +154,10 @@ def add_results(self, results: Dict[str, Any]):
if len(matches) == 0:
return

if self.is_dotproduct is None:
if len(matches) == 1:
# This condition should match the second time we add results containing
# only one match. We need at least two matches in a single response in order
# to infer the similarity metric
raise QueryResultsAggregregatorNotEnoughResultsError()
self.is_dotproduct = self._is_dotproduct_index(matches)

if self.is_dotproduct:
self._process_matches(matches, ns, self._dotproduct_heap_item)
if self.is_bigger_better:
self._process_matches(matches, ns, self._bigger_better_heap_item)
else:
self._process_matches(matches, ns, self._non_dotproduct_heap_item)
self._process_matches(matches, ns, self._smaller_better_heap_item)

def get_results(self) -> QueryNamespacesResults:
if self.read:
Expand Down
5 changes: 3 additions & 2 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast, Literal

from google.protobuf import json_format

Expand Down Expand Up @@ -409,6 +409,7 @@ def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
Expand All @@ -422,7 +423,7 @@ def query_namespaces(
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)
aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric)

target_namespaces = set(namespaces) # dedup namespaces
futures = [
Expand Down
69 changes: 35 additions & 34 deletions tests/integration/data/test_query_namespaces.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import pytest
from ..helpers import random_string, poll_stats_for_namespace
from pinecone.data.query_results_aggregator import (
QueryResultsAggregatorInvalidTopKError,
QueryResultsAggregregatorNotEnoughResultsError,
)

from pinecone import Vector


class TestQueryNamespacesRest:
def test_query_namespaces(self, idx):
def test_query_namespaces(self, idx, metric):
ns_prefix = random_string(5)
ns1 = f"{ns_prefix}-ns1"
ns2 = f"{ns_prefix}-ns2"
Expand Down Expand Up @@ -50,6 +46,7 @@ def test_query_namespaces(self, idx):
results = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "drama"}},
Expand Down Expand Up @@ -84,6 +81,7 @@ def test_query_namespaces(self, idx):
results2 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3, f"{ns_prefix}-nonexistent"],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "action"}},
Expand All @@ -98,6 +96,7 @@ def test_query_namespaces(self, idx):
results3 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -110,6 +109,7 @@ def test_query_namespaces(self, idx):
results4 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "comedy"}},
Expand All @@ -122,6 +122,7 @@ def test_query_namespaces(self, idx):
results5 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -137,6 +138,7 @@ def test_query_namespaces(self, idx):
f"{ns_prefix}-nonexistent2",
f"{ns_prefix}-nonexistent3",
],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "comedy"}},
Expand All @@ -145,22 +147,7 @@ def test_query_namespaces(self, idx):
assert len(results6.matches) == 0
assert results6.usage.read_units > 0

def test_invalid_top_k(self, idx):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need this test anymore since top_k of 1 is now valid.

with pytest.raises(QueryResultsAggregatorInvalidTopKError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=["ns1", "ns2", "ns3"],
include_values=True,
include_metadata=True,
filter={},
top_k=1,
)
assert (
str(e.value)
== "Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2."
)

def test_unmergeable_results(self, idx):
def test_single_result_per_namespace(self, idx):
ns_prefix = random_string(5)
ns1 = f"{ns_prefix}-ns1"
ns2 = f"{ns_prefix}-ns2"
Expand All @@ -183,26 +170,27 @@ def test_unmergeable_results(self, idx):
poll_stats_for_namespace(idx, namespace=ns1, expected_count=2)
poll_stats_for_namespace(idx, namespace=ns2, expected_count=2)

with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2],
include_values=True,
include_metadata=True,
filter={"key": {"$eq": 1}},
top_k=2,
)

assert (
str(e.value)
== "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
results = idx.query_namespaces(
vector=[0.1, 0.21],
namespaces=[ns1, ns2],
metric="cosine",
include_values=True,
include_metadata=True,
filter={"key": {"$eq": 1}},
top_k=2,
)
assert len(results.matches) == 2
assert results.matches[0].id == "id1"
assert results.matches[0].namespace == ns1
assert results.matches[1].id == "id5"
assert results.matches[1].namespace == ns2

def test_missing_namespaces(self, idx):
with pytest.raises(ValueError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[],
metric="cosine",
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -214,9 +202,22 @@ def test_missing_namespaces(self, idx):
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=None,
metric="cosine",
include_values=True,
include_metadata=True,
filter={},
top_k=2,
)
assert str(e.value) == "At least one namespace must be specified"

def test_missing_metric(self, idx):
with pytest.raises(TypeError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=["ns1"],
include_values=True,
include_metadata=True,
filter={},
top_k=2,
)
assert "query_namespaces() missing 1 required positional argument: 'metric'" in str(e.value)
Loading
Loading