Skip to content

Commit

Permalink
Add type hints for metric kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Dec 6, 2024
1 parent 98a3f79 commit 6d6f63e
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 182 deletions.
7 changes: 5 additions & 2 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import json
from typing import Union, List, Optional, Dict, Any
from typing import Union, List, Optional, Dict, Any, Literal

from pinecone.config import ConfigBuilder

Expand Down Expand Up @@ -511,6 +511,7 @@ def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
Expand Down Expand Up @@ -540,6 +541,7 @@ def query_namespaces(
combined_results = index.query_namespaces(
vector=query_vec,
namespaces=['ns1', 'ns2', 'ns3', 'ns4'],
metric="cosine",
top_k=10,
filter={'genre': {"$eq": "drama"}},
include_values=True,
Expand All @@ -554,6 +556,7 @@ def query_namespaces(
vector (List[float]): The query vector, must be the same length as the dimension of the index being queried.
namespaces (List[str]): The list of namespaces to query.
top_k (Optional[int], optional): The number of results you would like to request from each namespace. Defaults to 10.
metric (str): Must be one of 'cosine', 'euclidean', 'dotproduct'. This is needed in order to merge results across namespaces, since the interpretation of score depends on the index metric type.
filter (Optional[Dict[str, Union[str, float, int, bool, List, dict]]], optional): Pass an optional filter to filter results based on metadata. Defaults to None.
include_values (Optional[bool], optional): Boolean field indicating whether vector values should be included with results. Defaults to None.
include_metadata (Optional[bool], optional): Boolean field indicating whether vector metadata should be included with results. Defaults to None.
Expand All @@ -568,7 +571,7 @@ def query_namespaces(
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)
aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric)

target_namespaces = set(namespaces) # dedup namespaces
async_futures = [
Expand Down
62 changes: 23 additions & 39 deletions pinecone/data/query_results_aggregator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional, Any, Dict
from typing import List, Tuple, Optional, Any, Dict, Literal
import json
import heapq
from pinecone.core.openapi.data.models import Usage
Expand Down Expand Up @@ -88,46 +88,38 @@ def __repr__(self):
)


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
def __init__(self):
super().__init__(
"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
)


class QueryResultsAggregatorInvalidTopKError(Exception):
def __init__(self, top_k: int):
super().__init__(
f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
)
super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.")


class QueryResultsAggregator:
def __init__(self, top_k: int):
if top_k < 2:
def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]):
if top_k < 1:
raise QueryResultsAggregatorInvalidTopKError(top_k)

if metric in ["dotproduct", "cosine"]:
self.is_bigger_better = True
elif metric in ["euclidean"]:
self.is_bigger_better = False
else:
raise ValueError(
f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'."
)

self.top_k = top_k
self.usage_read_units = 0
self.heap: List[Tuple[float, int, object, str]] = []
self.insertion_counter = 0
self.is_dotproduct = None
self.read = False
self.final_results: Optional[QueryNamespacesResults] = None

def _is_dotproduct_index(self, matches):
# The interpretation of the score depends on the similar metric used.
# Unlike other index types, in indexes configured for dotproduct,
# a higher score is better. We have to infer this is the case by inspecting
# the order of the scores in the results.
for i in range(1, len(matches)):
if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase
return False
return True

def _dotproduct_heap_item(self, match, ns):
def _bigger_better_heap_item(self, match, ns):
# This 4-tuple is used to ensure that the heap is sorted by score followed by
# insertion order. The insertion order is used to break any ties in the score.
return (match.get("score"), -self.insertion_counter, match, ns)

def _non_dotproduct_heap_item(self, match, ns):
def _smaller_better_heap_item(self, match, ns):
return (-match.get("score"), -self.insertion_counter, match, ns)

def _process_matches(self, matches, ns, heap_item_fn):
Expand All @@ -137,10 +129,10 @@ def _process_matches(self, matches, ns, heap_item_fn):
heapq.heappush(self.heap, heap_item_fn(match, ns))
else:
# Assume we have dotproduct scores sorted in descending order
if self.is_dotproduct and match["score"] < self.heap[0][0]:
if self.is_bigger_better and match["score"] < self.heap[0][0]:
# No further matches can improve the top-K heap
break
elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
elif not self.is_bigger_better and match["score"] > -self.heap[0][0]:
# No further matches can improve the top-K heap
break
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
Expand All @@ -162,18 +154,10 @@ def add_results(self, results: Dict[str, Any]):
if len(matches) == 0:
return

if self.is_dotproduct is None:
if len(matches) == 1:
# This condition should match the second time we add results containing
# only one match. We need at least two matches in a single response in order
# to infer the similarity metric
raise QueryResultsAggregregatorNotEnoughResultsError()
self.is_dotproduct = self._is_dotproduct_index(matches)

if self.is_dotproduct:
self._process_matches(matches, ns, self._dotproduct_heap_item)
if self.is_bigger_better:
self._process_matches(matches, ns, self._bigger_better_heap_item)
else:
self._process_matches(matches, ns, self._non_dotproduct_heap_item)
self._process_matches(matches, ns, self._smaller_better_heap_item)

def get_results(self) -> QueryNamespacesResults:
if self.read:
Expand Down
5 changes: 3 additions & 2 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast, Literal

from google.protobuf import json_format

Expand Down Expand Up @@ -409,6 +409,7 @@ def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
Expand All @@ -422,7 +423,7 @@ def query_namespaces(
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)
aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric)

target_namespaces = set(namespaces) # dedup namespaces
futures = [
Expand Down
69 changes: 35 additions & 34 deletions tests/integration/data/test_query_namespaces.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import pytest
from ..helpers import random_string, poll_stats_for_namespace
from pinecone.data.query_results_aggregator import (
QueryResultsAggregatorInvalidTopKError,
QueryResultsAggregregatorNotEnoughResultsError,
)

from pinecone import Vector


class TestQueryNamespacesRest:
def test_query_namespaces(self, idx):
def test_query_namespaces(self, idx, metric):
ns_prefix = random_string(5)
ns1 = f"{ns_prefix}-ns1"
ns2 = f"{ns_prefix}-ns2"
Expand Down Expand Up @@ -50,6 +46,7 @@ def test_query_namespaces(self, idx):
results = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "drama"}},
Expand Down Expand Up @@ -84,6 +81,7 @@ def test_query_namespaces(self, idx):
results2 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3, f"{ns_prefix}-nonexistent"],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "action"}},
Expand All @@ -98,6 +96,7 @@ def test_query_namespaces(self, idx):
results3 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -110,6 +109,7 @@ def test_query_namespaces(self, idx):
results4 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "comedy"}},
Expand All @@ -122,6 +122,7 @@ def test_query_namespaces(self, idx):
results5 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -137,6 +138,7 @@ def test_query_namespaces(self, idx):
f"{ns_prefix}-nonexistent2",
f"{ns_prefix}-nonexistent3",
],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "comedy"}},
Expand All @@ -145,22 +147,7 @@ def test_query_namespaces(self, idx):
assert len(results6.matches) == 0
assert results6.usage.read_units > 0

def test_invalid_top_k(self, idx):
with pytest.raises(QueryResultsAggregatorInvalidTopKError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=["ns1", "ns2", "ns3"],
include_values=True,
include_metadata=True,
filter={},
top_k=1,
)
assert (
str(e.value)
== "Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2."
)

def test_unmergeable_results(self, idx):
def test_single_result_per_namespace(self, idx):
ns_prefix = random_string(5)
ns1 = f"{ns_prefix}-ns1"
ns2 = f"{ns_prefix}-ns2"
Expand All @@ -183,26 +170,27 @@ def test_unmergeable_results(self, idx):
poll_stats_for_namespace(idx, namespace=ns1, expected_count=2)
poll_stats_for_namespace(idx, namespace=ns2, expected_count=2)

with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2],
include_values=True,
include_metadata=True,
filter={"key": {"$eq": 1}},
top_k=2,
)

assert (
str(e.value)
== "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
results = idx.query_namespaces(
vector=[0.1, 0.21],
namespaces=[ns1, ns2],
metric="cosine",
include_values=True,
include_metadata=True,
filter={"key": {"$eq": 1}},
top_k=2,
)
assert len(results.matches) == 2
assert results.matches[0].id == "id1"
assert results.matches[0].namespace == ns1
assert results.matches[1].id == "id5"
assert results.matches[1].namespace == ns2

def test_missing_namespaces(self, idx):
with pytest.raises(ValueError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[],
metric="cosine",
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -214,9 +202,22 @@ def test_missing_namespaces(self, idx):
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=None,
metric="cosine",
include_values=True,
include_metadata=True,
filter={},
top_k=2,
)
assert str(e.value) == "At least one namespace must be specified"

def test_missing_metric(self, idx):
with pytest.raises(TypeError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=["ns1"],
include_values=True,
include_metadata=True,
filter={},
top_k=2,
)
assert str(e.value) == "query_namespaces() missing 1 required positional argument: 'metric'"
Loading

0 comments on commit 6d6f63e

Please sign in to comment.