From 6d6f63ecc54f3158570d326c1b7a73289cd063da Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 6 Dec 2024 13:47:41 -0500 Subject: [PATCH] Add type hints for metric kwarg --- pinecone/data/index.py | 7 +- pinecone/data/query_results_aggregator.py | 62 ++--- pinecone/grpc/index_grpc.py | 5 +- .../integration/data/test_query_namespaces.py | 69 ++--- tests/unit/test_query_results_aggregator.py | 247 ++++++++++-------- 5 files changed, 208 insertions(+), 182 deletions(-) diff --git a/pinecone/data/index.py b/pinecone/data/index.py index eb406250..138bdb91 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -2,7 +2,7 @@ import logging import json -from typing import Union, List, Optional, Dict, Any +from typing import Union, List, Optional, Dict, Any, Literal from pinecone.config import ConfigBuilder @@ -511,6 +511,7 @@ def query_namespaces( self, vector: List[float], namespaces: List[str], + metric: Literal["cosine", "euclidean", "dotproduct"], top_k: Optional[int] = None, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, include_values: Optional[bool] = None, @@ -540,6 +541,7 @@ def query_namespaces( combined_results = index.query_namespaces( vector=query_vec, namespaces=['ns1', 'ns2', 'ns3', 'ns4'], + metric="cosine", top_k=10, filter={'genre': {"$eq": "drama"}}, include_values=True, @@ -554,6 +556,7 @@ def query_namespaces( vector (List[float]): The query vector, must be the same length as the dimension of the index being queried. namespaces (List[str]): The list of namespaces to query. top_k (Optional[int], optional): The number of results you would like to request from each namespace. Defaults to 10. + metric (str): Must be one of 'cosine', 'euclidean', 'dotproduct'. This is needed in order to merge results across namespaces, since the interpretation of score depends on the index metric type. filter (Optional[Dict[str, Union[str, float, int, bool, List, dict]]], optional): Pass an optional filter to filter results based on metadata. Defaults to None. include_values (Optional[bool], optional): Boolean field indicating whether vector values should be included with results. Defaults to None. include_metadata (Optional[bool], optional): Boolean field indicating whether vector metadata should be included with results. Defaults to None. @@ -568,7 +571,7 @@ def query_namespaces( raise ValueError("Query vector must not be empty") overall_topk = top_k if top_k is not None else 10 - aggregator = QueryResultsAggregator(top_k=overall_topk) + aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric) target_namespaces = set(namespaces) # dedup namespaces async_futures = [ diff --git a/pinecone/data/query_results_aggregator.py b/pinecone/data/query_results_aggregator.py index 98ca77a2..0777981a 100644 --- a/pinecone/data/query_results_aggregator.py +++ b/pinecone/data/query_results_aggregator.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Optional, Any, Dict +from typing import List, Tuple, Optional, Any, Dict, Literal import json import heapq from pinecone.core.openapi.data.models import Usage @@ -88,46 +88,38 @@ def __repr__(self): ) -class QueryResultsAggregregatorNotEnoughResultsError(Exception): - def __init__(self): - super().__init__( - "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." - ) - - class QueryResultsAggregatorInvalidTopKError(Exception): def __init__(self, top_k: int): - super().__init__( - f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2." - ) + super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.") class QueryResultsAggregator: - def __init__(self, top_k: int): - if top_k < 2: + def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]): + if top_k < 1: raise QueryResultsAggregatorInvalidTopKError(top_k) + + if metric in ["dotproduct", "cosine"]: + self.is_bigger_better = True + elif metric in ["euclidean"]: + self.is_bigger_better = False + else: + raise ValueError( + f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'." + ) + self.top_k = top_k self.usage_read_units = 0 self.heap: List[Tuple[float, int, object, str]] = [] self.insertion_counter = 0 - self.is_dotproduct = None self.read = False self.final_results: Optional[QueryNamespacesResults] = None - def _is_dotproduct_index(self, matches): - # The interpretation of the score depends on the similar metric used. - # Unlike other index types, in indexes configured for dotproduct, - # a higher score is better. We have to infer this is the case by inspecting - # the order of the scores in the results. - for i in range(1, len(matches)): - if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase - return False - return True - - def _dotproduct_heap_item(self, match, ns): + def _bigger_better_heap_item(self, match, ns): + # This 4-tuple is used to ensure that the heap is sorted by score followed by + # insertion order. The insertion order is used to break any ties in the score. return (match.get("score"), -self.insertion_counter, match, ns) - def _non_dotproduct_heap_item(self, match, ns): + def _smaller_better_heap_item(self, match, ns): return (-match.get("score"), -self.insertion_counter, match, ns) def _process_matches(self, matches, ns, heap_item_fn): @@ -137,10 +129,10 @@ def _process_matches(self, matches, ns, heap_item_fn): heapq.heappush(self.heap, heap_item_fn(match, ns)) else: # Assume we have dotproduct scores sorted in descending order - if self.is_dotproduct and match["score"] < self.heap[0][0]: + if self.is_bigger_better and match["score"] < self.heap[0][0]: # No further matches can improve the top-K heap break - elif not self.is_dotproduct and match["score"] > -self.heap[0][0]: + elif not self.is_bigger_better and match["score"] > -self.heap[0][0]: # No further matches can improve the top-K heap break heapq.heappushpop(self.heap, heap_item_fn(match, ns)) @@ -162,18 +154,10 @@ def add_results(self, results: Dict[str, Any]): if len(matches) == 0: return - if self.is_dotproduct is None: - if len(matches) == 1: - # This condition should match the second time we add results containing - # only one match. We need at least two matches in a single response in order - # to infer the similarity metric - raise QueryResultsAggregregatorNotEnoughResultsError() - self.is_dotproduct = self._is_dotproduct_index(matches) - - if self.is_dotproduct: - self._process_matches(matches, ns, self._dotproduct_heap_item) + if self.is_bigger_better: + self._process_matches(matches, ns, self._bigger_better_heap_item) else: - self._process_matches(matches, ns, self._non_dotproduct_heap_item) + self._process_matches(matches, ns, self._smaller_better_heap_item) def get_results(self) -> QueryNamespacesResults: if self.read: diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 6791ae68..46f541d9 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast +from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast, Literal from google.protobuf import json_format @@ -409,6 +409,7 @@ def query_namespaces( self, vector: List[float], namespaces: List[str], + metric: Literal["cosine", "euclidean", "dotproduct"], top_k: Optional[int] = None, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, include_values: Optional[bool] = None, @@ -422,7 +423,7 @@ def query_namespaces( raise ValueError("Query vector must not be empty") overall_topk = top_k if top_k is not None else 10 - aggregator = QueryResultsAggregator(top_k=overall_topk) + aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric) target_namespaces = set(namespaces) # dedup namespaces futures = [ diff --git a/tests/integration/data/test_query_namespaces.py b/tests/integration/data/test_query_namespaces.py index 414cea69..ba4ed821 100644 --- a/tests/integration/data/test_query_namespaces.py +++ b/tests/integration/data/test_query_namespaces.py @@ -1,15 +1,11 @@ import pytest from ..helpers import random_string, poll_stats_for_namespace -from pinecone.data.query_results_aggregator import ( - QueryResultsAggregatorInvalidTopKError, - QueryResultsAggregregatorNotEnoughResultsError, -) from pinecone import Vector class TestQueryNamespacesRest: - def test_query_namespaces(self, idx): + def test_query_namespaces(self, idx, metric): ns_prefix = random_string(5) ns1 = f"{ns_prefix}-ns1" ns2 = f"{ns_prefix}-ns2" @@ -50,6 +46,7 @@ def test_query_namespaces(self, idx): results = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "drama"}}, @@ -84,6 +81,7 @@ def test_query_namespaces(self, idx): results2 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3, f"{ns_prefix}-nonexistent"], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "action"}}, @@ -98,6 +96,7 @@ def test_query_namespaces(self, idx): results3 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={}, @@ -110,6 +109,7 @@ def test_query_namespaces(self, idx): results4 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "comedy"}}, @@ -122,6 +122,7 @@ def test_query_namespaces(self, idx): results5 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={}, @@ -137,6 +138,7 @@ def test_query_namespaces(self, idx): f"{ns_prefix}-nonexistent2", f"{ns_prefix}-nonexistent3", ], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "comedy"}}, @@ -145,22 +147,7 @@ def test_query_namespaces(self, idx): assert len(results6.matches) == 0 assert results6.usage.read_units > 0 - def test_invalid_top_k(self, idx): - with pytest.raises(QueryResultsAggregatorInvalidTopKError) as e: - idx.query_namespaces( - vector=[0.1, 0.2], - namespaces=["ns1", "ns2", "ns3"], - include_values=True, - include_metadata=True, - filter={}, - top_k=1, - ) - assert ( - str(e.value) - == "Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2." - ) - - def test_unmergeable_results(self, idx): + def test_single_result_per_namespace(self, idx): ns_prefix = random_string(5) ns1 = f"{ns_prefix}-ns1" ns2 = f"{ns_prefix}-ns2" @@ -183,26 +170,27 @@ def test_unmergeable_results(self, idx): poll_stats_for_namespace(idx, namespace=ns1, expected_count=2) poll_stats_for_namespace(idx, namespace=ns2, expected_count=2) - with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError) as e: - idx.query_namespaces( - vector=[0.1, 0.2], - namespaces=[ns1, ns2], - include_values=True, - include_metadata=True, - filter={"key": {"$eq": 1}}, - top_k=2, - ) - - assert ( - str(e.value) - == "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." + results = idx.query_namespaces( + vector=[0.1, 0.21], + namespaces=[ns1, ns2], + metric="cosine", + include_values=True, + include_metadata=True, + filter={"key": {"$eq": 1}}, + top_k=2, ) + assert len(results.matches) == 2 + assert results.matches[0].id == "id1" + assert results.matches[0].namespace == ns1 + assert results.matches[1].id == "id5" + assert results.matches[1].namespace == ns2 def test_missing_namespaces(self, idx): with pytest.raises(ValueError) as e: idx.query_namespaces( vector=[0.1, 0.2], namespaces=[], + metric="cosine", include_values=True, include_metadata=True, filter={}, @@ -214,9 +202,22 @@ def test_missing_namespaces(self, idx): idx.query_namespaces( vector=[0.1, 0.2], namespaces=None, + metric="cosine", include_values=True, include_metadata=True, filter={}, top_k=2, ) assert str(e.value) == "At least one namespace must be specified" + + def test_missing_metric(self, idx): + with pytest.raises(TypeError) as e: + idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=["ns1"], + include_values=True, + include_metadata=True, + filter={}, + top_k=2, + ) + assert str(e.value) == "query_namespaces() missing 1 required positional argument: 'metric'" diff --git a/tests/unit/test_query_results_aggregator.py b/tests/unit/test_query_results_aggregator.py index c482ca15..b40a11d2 100644 --- a/tests/unit/test_query_results_aggregator.py +++ b/tests/unit/test_query_results_aggregator.py @@ -1,15 +1,14 @@ from pinecone.data.query_results_aggregator import ( QueryResultsAggregator, QueryResultsAggregatorInvalidTopKError, - QueryResultsAggregregatorNotEnoughResultsError, ) import random import pytest -class TestQueryResultsAggregator: +class TestQueryResultsAggregator_EuclideanIndex: def test_keeps_running_usage_total(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [ @@ -53,7 +52,7 @@ def test_keeps_running_usage_total(self): assert results.get("usage", {}).get("read_units") == results.usage.read_units def test_inserting_duplicate_scores_stable_ordering(self): - aggregator = QueryResultsAggregator(top_k=5) + aggregator = QueryResultsAggregator(top_k=5, metric="euclidean") results1 = { "matches": [ @@ -95,47 +94,8 @@ def test_inserting_duplicate_scores_stable_ordering(self): assert results.matches[4].id == "4" # 0.22 assert results.matches[4].namespace == "ns1" - def test_correctly_handles_dotproduct_metric(self): - # For this index metric, the higher the score, the more similar the vectors are. - # We have to infer that we have this type of index by seeing whether scores are - # sorted in descending or ascending order. - aggregator = QueryResultsAggregator(top_k=3) - - desc_results1 = { - "matches": [ - {"id": "1", "score": 0.9, "values": []}, - {"id": "2", "score": 0.8, "values": []}, - {"id": "3", "score": 0.7, "values": []}, - {"id": "4", "score": 0.6, "values": []}, - {"id": "5", "score": 0.5, "values": []}, - ], - "usage": {"readUnits": 5}, - "namespace": "ns1", - } - aggregator.add_results(desc_results1) - - results2 = { - "matches": [ - {"id": "7", "score": 0.77, "values": []}, - {"id": "8", "score": 0.88, "values": []}, - {"id": "9", "score": 0.99, "values": []}, - {"id": "10", "score": 0.1010, "values": []}, - {"id": "11", "score": 0.1111, "values": []}, - ], - "usage": {"readUnits": 7}, - "namespace": "ns2", - } - aggregator.add_results(results2) - - results = aggregator.get_results() - assert results.usage.read_units == 12 - assert len(results.matches) == 3 - assert results.matches[0].id == "9" # 0.99 - assert results.matches[1].id == "1" # 0.9 - assert results.matches[2].id == "8" # 0.88 - def test_still_correct_with_early_return(self): - aggregator = QueryResultsAggregator(top_k=5) + aggregator = QueryResultsAggregator(top_k=5, metric="euclidean") results1 = { "matches": [ @@ -172,8 +132,8 @@ def test_still_correct_with_early_return(self): assert results.matches[3].id == "2" assert results.matches[4].id == "3" - def test_still_correct_with_early_return_generated_nont_dotproduct(self): - aggregator = QueryResultsAggregator(top_k=1000) + def test_fewer_results_than_topk(self): + aggregator = QueryResultsAggregator(top_k=1000, metric="euclidean") matches1 = [ {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) ] @@ -223,8 +183,93 @@ def test_still_correct_with_early_return_generated_nont_dotproduct(self): assert results.matches[3].score == merged_matches[3]["score"] assert results.matches[4].score == merged_matches[4]["score"] - def test_still_correct_with_early_return_generated_dotproduct(self): - aggregator = QueryResultsAggregator(top_k=1000) + +class TestQueryResultsAggregator_CosineOrDotproductIndex: + def test_inserting_duplicate_scores_stable_ordering(self): + aggregator = QueryResultsAggregator(top_k=6, metric="cosine") + + results1 = { + "matches": [ + {"id": "1", "score": 0.94, "values": []}, + {"id": "3", "score": 0.94, "values": []}, + {"id": "2", "score": 0.94, "values": []}, + {"id": "4", "score": 0.88, "values": []}, + {"id": "5", "score": 0.88, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.93, "values": []}, + {"id": "7", "score": 0.93, "values": []}, + {"id": "8", "score": 0.85, "values": []}, + {"id": "9", "score": 0.85, "values": []}, + {"id": "10", "score": 0.80, "values": []}, + ], + "usage": {"readUnits": 8}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 13 + assert len(results.matches) == 6 + assert results.matches[0].id == "1" + assert results.matches[0].namespace == "ns1" + assert results.matches[1].id == "3" + assert results.matches[1].namespace == "ns1" + assert results.matches[2].id == "2" + assert results.matches[2].namespace == "ns1" + assert results.matches[3].id == "6" + assert results.matches[3].namespace == "ns2" + assert results.matches[4].id == "7" + assert results.matches[4].namespace == "ns2" + assert results.matches[5].id == "4" + assert results.matches[5].namespace == "ns1" + + @pytest.mark.parametrize("index_metric", ["cosine", "dotproduct"]) + def test_correctly_handles_dotproduct_and_cosine_metric(self, index_metric): + # For this index metric, the higher the score, the more similar the vectors are. + aggregator = QueryResultsAggregator(top_k=3, metric=index_metric) + + desc_results1 = { + "matches": [ + {"id": "1", "score": 0.9, "values": []}, + {"id": "2", "score": 0.8, "values": []}, + {"id": "3", "score": 0.7, "values": []}, + {"id": "4", "score": 0.6, "values": []}, + {"id": "5", "score": 0.5, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(desc_results1) + + results2 = { + "matches": [ + {"id": "7", "score": 0.77, "values": []}, + {"id": "8", "score": 0.88, "values": []}, + {"id": "9", "score": 0.99, "values": []}, + {"id": "10", "score": 0.1010, "values": []}, + {"id": "11", "score": 0.1111, "values": []}, + ], + "usage": {"readUnits": 7}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 12 + assert len(results.matches) == 3 + assert results.matches[0].id == "9" # 0.99 + assert results.matches[1].id == "1" # 0.9 + assert results.matches[2].id == "8" # 0.88 + + def test_lots_of_results(self): + aggregator = QueryResultsAggregator(top_k=1000, metric="dotproduct") matches1 = [ {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) ] @@ -277,7 +322,7 @@ def test_still_correct_with_early_return_generated_dotproduct(self): class TestQueryResultsAggregatorOutputUX: def test_can_interact_with_attributes(self): - aggregator = QueryResultsAggregator(top_k=2) + aggregator = QueryResultsAggregator(top_k=2, metric="euclidean") results1 = { "matches": [ { @@ -306,7 +351,7 @@ def test_can_interact_with_attributes(self): assert results.matches[0].values == [0.31, 0.32, 0.33, 0.34, 0.35, 0.36] def test_can_interact_like_dict(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [ { @@ -345,13 +390,13 @@ def test_can_interact_like_dict(self): assert results["matches"][0]["score"] == 0.3 def test_can_print_empty_results_without_error(self, capsys): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results = aggregator.get_results() print(results) capsys.readouterr() def test_can_print_results_containing_None_without_error(self, capsys): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [ {"id": "1", "score": 0.1}, @@ -369,7 +414,7 @@ def test_can_print_results_containing_None_without_error(self, capsys): capsys.readouterr() def test_can_print_results_containing_short_vectors(self, capsys): - aggregator = QueryResultsAggregator(top_k=4) + aggregator = QueryResultsAggregator(top_k=4, metric="euclidean") results1 = { "matches": [ {"id": "1", "score": 0.1, "values": [0.31]}, @@ -386,7 +431,7 @@ def test_can_print_results_containing_short_vectors(self, capsys): capsys.readouterr() def test_can_print_complete_results_without_error(self, capsys): - aggregator = QueryResultsAggregator(top_k=2) + aggregator = QueryResultsAggregator(top_k=2, metric="euclidean") results1 = { "matches": [ { @@ -421,29 +466,45 @@ def test_can_print_complete_results_without_error(self, capsys): class TestQueryAggregatorEdgeCases: def test_topK_too_small(self): with pytest.raises(QueryResultsAggregatorInvalidTopKError): - QueryResultsAggregator(top_k=0) - with pytest.raises(QueryResultsAggregatorInvalidTopKError): - QueryResultsAggregator(top_k=1) + QueryResultsAggregator(top_k=0, metric="euclidean") + + def test_results_never_added(self): + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 0 + assert len(results.matches) == 0 - def test_matches_too_small(self): - aggregator = QueryResultsAggregator(top_k=3) + @pytest.mark.parametrize("index_metric", ["euclidean", "cosine", "dotproduct"]) + def test_tie_breaking(self, index_metric): + aggregator = QueryResultsAggregator(top_k=3, metric=index_metric) results1 = { - "matches": [{"id": "1", "score": 0.1}], + "matches": [{"id": "5", "score": 0.9}, {"id": "2", "score": 0.9}], "usage": {"readUnits": 5}, "namespace": "ns1", } - with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): - aggregator.add_results(results1) - def test_empty_results(self): - aggregator = QueryResultsAggregator(top_k=3) + results2 = { + "matches": [{"id": "3", "score": 0.9}, {"id": "4", "score": 0.9}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results1) + aggregator.add_results(results2) results = aggregator.get_results() - assert results is not None - assert results.usage.read_units == 0 - assert len(results.matches) == 0 + assert results.usage.read_units == 10 + assert len(results.matches) == 3 + + # Maintains order results were added when resolving ties + assert results.matches[0].id == "5" + assert results.matches[0].namespace == "ns1" + assert results.matches[1].id == "2" + assert results.matches[1].namespace == "ns1" + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns2" def test_empty_results_with_usage(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) @@ -455,7 +516,7 @@ def test_empty_results_with_usage(self): assert len(results.matches) == 0 def test_exactly_one_result(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], "usage": {"readUnits": 5}, @@ -482,56 +543,32 @@ def test_exactly_one_result(self): assert results.matches[2].namespace == "ns2" assert results.matches[2].score == 0.2 - def test_two_result_sets_with_single_result_errors(self): - with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): - aggregator = QueryResultsAggregator(top_k=3) - results1 = { - "matches": [{"id": "1", "score": 0.1}], - "usage": {"readUnits": 5}, - "namespace": "ns1", - } - aggregator.add_results(results1) - results2 = { - "matches": [{"id": "2", "score": 0.01}], - "usage": {"readUnits": 5}, - "namespace": "ns2", - } - aggregator.add_results(results2) - - def test_single_result_after_index_type_known_no_error(self): - aggregator = QueryResultsAggregator(top_k=3) - - results3 = { - "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], - "usage": {"readUnits": 5}, - "namespace": "ns3", - } - aggregator.add_results(results3) - + def test_two_result_sets_with_single_result(self): + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [{"id": "1", "score": 0.1}], "usage": {"readUnits": 5}, "namespace": "ns1", } aggregator.add_results(results1) - results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"} + results2 = { + "matches": [{"id": "2", "score": 0.01}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } aggregator.add_results(results2) - results = aggregator.get_results() - assert results.usage.read_units == 15 - assert len(results.matches) == 3 + assert results.usage.read_units == 10 + assert len(results.matches) == 2 assert results.matches[0].id == "2" - assert results.matches[0].namespace == "ns3" + assert results.matches[0].namespace == "ns2" assert results.matches[0].score == 0.01 assert results.matches[1].id == "1" assert results.matches[1].namespace == "ns1" assert results.matches[1].score == 0.1 - assert results.matches[2].id == "3" - assert results.matches[2].namespace == "ns3" - assert results.matches[2].score == 0.2 def test_all_empty_results(self): - aggregator = QueryResultsAggregator(top_k=10) + aggregator = QueryResultsAggregator(top_k=10, metric="euclidean") aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) @@ -543,7 +580,7 @@ def test_all_empty_results(self): assert len(results.matches) == 0 def test_some_empty_results(self): - aggregator = QueryResultsAggregator(top_k=10) + aggregator = QueryResultsAggregator(top_k=10, metric="euclidean") results2 = { "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], "usage": {"readUnits": 5},