From bb9d9c4135c0938e2b4e96068cfe6b02ea40732f Mon Sep 17 00:00:00 2001 From: Jennifer Hamon Date: Fri, 6 Dec 2024 18:18:11 -0500 Subject: [PATCH] [Bug] query_namespaces can handle single result (#421) In order to merge results across multiple queries, the SDK must know which similarity metric an index is using. For dotproduct and cosine indexes a larger score is better, while for euclidean a smaller score is better. Unfortunately, the data plane API does not currently expose the metric type, and a separate call to the control plane to find out seems undesirable from a resiliency and performance perspective. As a workaround, the initial implementation of `query_namespaces` had the SDK infer the similarity metric needed to merge results by checking whether the scores of query results were ascending or descending. This worked well, but it imposed an implicit limitation that each query must return at least 2 results. We initially believed this would not be a problem, but we have since learned that applications using filtering can sometimes filter out all or most results. To handle these edge cases with 1 or 0 results, the SDK now requires the user to explicitly state which similarity metric the index uses. - Add a required kwarg to `query_namespaces` to specify the index similarity metric. - Modify `QueryResultsAggregator` to use this similarity metric, and strip out the code that inferred whether results were ascending or descending. - Adjust integration tests to pass the new metric kwarg. Except for adding the new kwarg, the query_namespaces integration tests did not need to change, which indicates the underlying behavior still works as before. - [x] Bug fix (non-breaking change which fixes an issue) --- pinecone/data/index.py | 5 +- pinecone/data/interfaces.py | 2 + pinecone/data/query_results_aggregator.py | 62 ++--- pinecone/grpc/index_grpc.py | 5 +- .../integration/data/test_query_namespaces.py | 69 ++--- tests/unit/test_query_results_aggregator.py | 247 ++++++++++-------- 6 files changed, 208 insertions(+), 182 deletions(-) diff --git a/pinecone/data/index.py b/pinecone/data/index.py index 3f74568d..90f7f827 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -2,7 +2,7 @@ import logging import json -from typing import Union, List, Optional, Dict, Any +from typing import Union, List, Optional, Dict, Any, Literal from pinecone.config import ConfigBuilder @@ -311,6 +311,7 @@ def query_namespaces( self, vector: List[float], namespaces: List[str], + metric: Literal["cosine", "euclidean", "dotproduct"], top_k: Optional[int] = None, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, include_values: Optional[bool] = None, @@ -326,7 +327,7 @@ def query_namespaces( raise ValueError("Query vector must not be empty") overall_topk = top_k if top_k is not None else 10 - aggregator = QueryResultsAggregator(top_k=overall_topk) + aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric) target_namespaces = set(namespaces) # dedup namespaces async_futures = [ diff --git a/pinecone/data/interfaces.py b/pinecone/data/interfaces.py index 9589099c..117155d9 100644 --- a/pinecone/data/interfaces.py +++ b/pinecone/data/interfaces.py @@ -254,6 +254,7 @@ def query_namespaces( combined_results = index.query_namespaces( vector=query_vec, namespaces=['ns1', 'ns2', 'ns3', 'ns4'], + metric="cosine", top_k=10, filter={'genre': {"$eq": "drama"}}, include_values=True, @@ -268,6 +269,7 @@ def query_namespaces( vector (List[float]): The query 
vector, must be the same length as the dimension of the index being queried. namespaces (List[str]): The list of namespaces to query. top_k (Optional[int], optional): The number of results you would like to request from each namespace. Defaults to 10. + metric (str): Must be one of 'cosine', 'euclidean', 'dotproduct'. This is needed in order to merge results across namespaces, since the interpretation of the score depends on the index metric type. filter (Optional[Dict[str, Union[str, float, int, bool, List, dict]]], optional): Pass an optional filter to filter results based on metadata. Defaults to None. include_values (Optional[bool], optional): Boolean field indicating whether vector values should be included with results. Defaults to None. include_metadata (Optional[bool], optional): Boolean field indicating whether vector metadata should be included with results. Defaults to None. diff --git a/pinecone/data/query_results_aggregator.py b/pinecone/data/query_results_aggregator.py index 651477e2..f7c0092e 100644 --- a/pinecone/data/query_results_aggregator.py +++ b/pinecone/data/query_results_aggregator.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Optional, Any, Dict +from typing import List, Tuple, Optional, Any, Dict, Literal import json import heapq from pinecone.core.openapi.db_data.models import Usage @@ -88,46 +88,38 @@ def __repr__(self): ) -class QueryResultsAggregregatorNotEnoughResultsError(Exception): - def __init__(self): - super().__init__( - "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." - ) - - class QueryResultsAggregatorInvalidTopKError(Exception): def __init__(self, top_k: int): - super().__init__( - f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2." - ) + super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.") class QueryResultsAggregator: - def __init__(self, top_k: int): - if top_k < 2: + def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]): + if top_k < 1: raise QueryResultsAggregatorInvalidTopKError(top_k) + + if metric in ["dotproduct", "cosine"]: + self.is_bigger_better = True + elif metric in ["euclidean"]: + self.is_bigger_better = False + else: + raise ValueError( + f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'." + ) + self.top_k = top_k self.usage_read_units = 0 self.heap: List[Tuple[float, int, object, str]] = [] self.insertion_counter = 0 - self.is_dotproduct = None self.read = False self.final_results: Optional[QueryNamespacesResults] = None - def _is_dotproduct_index(self, matches): - # The interpretation of the score depends on the similar metric used. - # Unlike other index types, in indexes configured for dotproduct, - # a higher score is better. We have to infer this is the case by inspecting - # the order of the scores in the results. - for i in range(1, len(matches)): - if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase - return False - return True - - def _dotproduct_heap_item(self, match, ns): + def _bigger_better_heap_item(self, match, ns): + # This 4-tuple is used to ensure that the heap is sorted by score followed by + # insertion order. The insertion order is used to break any ties in the score. 
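+        # For example, with a bigger-is-better metric the tuple (0.93, -1, ...) from
+        # an earlier insertion compares greater than (0.93, -2, ...) from a later one,
+        # so when scores are tied the earlier result keeps the higher rank.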
return (match.get("score"), -self.insertion_counter, match, ns) - def _non_dotproduct_heap_item(self, match, ns): + def _smaller_better_heap_item(self, match, ns): return (-match.get("score"), -self.insertion_counter, match, ns) def _process_matches(self, matches, ns, heap_item_fn): @@ -137,10 +129,10 @@ def _process_matches(self, matches, ns, heap_item_fn): heapq.heappush(self.heap, heap_item_fn(match, ns)) else: - # Assume we have dotproduct scores sorted in descending order - if self.is_dotproduct and match["score"] < self.heap[0][0]: + # Assume scores within each response are sorted from best to worst + if self.is_bigger_better and match["score"] < self.heap[0][0]: # No further matches can improve the top-K heap break - elif not self.is_dotproduct and match["score"] > -self.heap[0][0]: + elif not self.is_bigger_better and match["score"] > -self.heap[0][0]: # No further matches can improve the top-K heap break heapq.heappushpop(self.heap, heap_item_fn(match, ns)) @@ -162,18 +154,10 @@ def add_results(self, results: Dict[str, Any]): if len(matches) == 0: return - if self.is_dotproduct is None: - if len(matches) == 1: - # This condition should match the second time we add results containing - # only one match. We need at least two matches in a single response in order - # to infer the similarity metric - raise QueryResultsAggregregatorNotEnoughResultsError() - self.is_dotproduct = self._is_dotproduct_index(matches) - - if self.is_dotproduct: - self._process_matches(matches, ns, self._dotproduct_heap_item) + if self.is_bigger_better: + self._process_matches(matches, ns, self._bigger_better_heap_item) else: - self._process_matches(matches, ns, self._non_dotproduct_heap_item) + self._process_matches(matches, ns, self._smaller_better_heap_item) def get_results(self) -> QueryNamespacesResults: if self.read: diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 99807edf..170c17ac 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast +from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast, Literal from google.protobuf import json_format @@ -409,6 +409,7 @@ def query_namespaces( self, vector: List[float], namespaces: List[str], + metric: Literal["cosine", "euclidean", "dotproduct"], top_k: Optional[int] = None, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, include_values: Optional[bool] = None, @@ -422,7 +423,7 @@ def query_namespaces( raise ValueError("Query vector must not be empty") overall_topk = top_k if top_k is not None else 10 - aggregator = QueryResultsAggregator(top_k=overall_topk) + aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric) target_namespaces = set(namespaces) # dedup namespaces futures = [ diff --git a/tests/integration/data/test_query_namespaces.py b/tests/integration/data/test_query_namespaces.py index 414cea69..7100f573 100644 --- a/tests/integration/data/test_query_namespaces.py +++ b/tests/integration/data/test_query_namespaces.py @@ -1,15 +1,11 @@ import pytest from ..helpers import random_string, poll_stats_for_namespace -from pinecone.data.query_results_aggregator import ( - QueryResultsAggregatorInvalidTopKError, - QueryResultsAggregregatorNotEnoughResultsError, -) from pinecone import Vector class TestQueryNamespacesRest: - def test_query_namespaces(self, idx): + def test_query_namespaces(self, idx, metric): ns_prefix = random_string(5) ns1 = f"{ns_prefix}-ns1" ns2 = f"{ns_prefix}-ns2" ns3 = f"{ns_prefix}-ns3" @@ -50,6 
+46,7 @@ def test_query_namespaces(self, idx): results = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "drama"}}, @@ -84,6 +81,7 @@ def test_query_namespaces(self, idx): results2 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3, f"{ns_prefix}-nonexistent"], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "action"}}, @@ -98,6 +96,7 @@ def test_query_namespaces(self, idx): results3 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={}, @@ -110,6 +109,7 @@ def test_query_namespaces(self, idx): results4 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "comedy"}}, @@ -122,6 +122,7 @@ def test_query_namespaces(self, idx): results5 = idx.query_namespaces( vector=[0.1, 0.2], namespaces=[ns1, ns2, ns3], + metric=metric, include_values=True, include_metadata=True, filter={}, @@ -137,6 +138,7 @@ def test_query_namespaces(self, idx): f"{ns_prefix}-nonexistent2", f"{ns_prefix}-nonexistent3", ], + metric=metric, include_values=True, include_metadata=True, filter={"genre": {"$eq": "comedy"}}, @@ -145,22 +147,7 @@ def test_query_namespaces(self, idx): assert len(results6.matches) == 0 assert results6.usage.read_units > 0 - def test_invalid_top_k(self, idx): - with pytest.raises(QueryResultsAggregatorInvalidTopKError) as e: - idx.query_namespaces( - vector=[0.1, 0.2], - namespaces=["ns1", "ns2", "ns3"], - include_values=True, - include_metadata=True, - filter={}, - top_k=1, - ) - assert ( - str(e.value) - == "Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2." - ) - - def test_unmergeable_results(self, idx): + def test_single_result_per_namespace(self, idx): ns_prefix = random_string(5) ns1 = f"{ns_prefix}-ns1" ns2 = f"{ns_prefix}-ns2" @@ -183,26 +170,27 @@ def test_unmergeable_results(self, idx): poll_stats_for_namespace(idx, namespace=ns1, expected_count=2) poll_stats_for_namespace(idx, namespace=ns2, expected_count=2) - with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError) as e: - idx.query_namespaces( - vector=[0.1, 0.2], - namespaces=[ns1, ns2], - include_values=True, - include_metadata=True, - filter={"key": {"$eq": 1}}, - top_k=2, - ) - - assert ( - str(e.value) - == "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." 
+ results = idx.query_namespaces( + vector=[0.1, 0.21], + namespaces=[ns1, ns2], + metric="cosine", + include_values=True, + include_metadata=True, + filter={"key": {"$eq": 1}}, + top_k=2, ) + assert len(results.matches) == 2 + assert results.matches[0].id == "id1" + assert results.matches[0].namespace == ns1 + assert results.matches[1].id == "id5" + assert results.matches[1].namespace == ns2 def test_missing_namespaces(self, idx): with pytest.raises(ValueError) as e: idx.query_namespaces( vector=[0.1, 0.2], namespaces=[], + metric="cosine", include_values=True, include_metadata=True, filter={}, @@ -214,9 +202,22 @@ def test_missing_namespaces(self, idx): idx.query_namespaces( vector=[0.1, 0.2], namespaces=None, + metric="cosine", include_values=True, include_metadata=True, filter={}, top_k=2, ) assert str(e.value) == "At least one namespace must be specified" + + def test_missing_metric(self, idx): + with pytest.raises(TypeError) as e: + idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=["ns1"], + include_values=True, + include_metadata=True, + filter={}, + top_k=2, + ) + assert "query_namespaces() missing 1 required positional argument: 'metric'" in str(e.value) diff --git a/tests/unit/test_query_results_aggregator.py b/tests/unit/test_query_results_aggregator.py index c482ca15..b40a11d2 100644 --- a/tests/unit/test_query_results_aggregator.py +++ b/tests/unit/test_query_results_aggregator.py @@ -1,15 +1,14 @@ from pinecone.data.query_results_aggregator import ( QueryResultsAggregator, QueryResultsAggregatorInvalidTopKError, - QueryResultsAggregregatorNotEnoughResultsError, ) import random import pytest -class TestQueryResultsAggregator: +class TestQueryResultsAggregator_EuclideanIndex: def test_keeps_running_usage_total(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [ @@ -53,7 +52,7 @@ def test_keeps_running_usage_total(self): assert results.get("usage", {}).get("read_units") == results.usage.read_units def test_inserting_duplicate_scores_stable_ordering(self): - aggregator = QueryResultsAggregator(top_k=5) + aggregator = QueryResultsAggregator(top_k=5, metric="euclidean") results1 = { "matches": [ @@ -95,47 +94,8 @@ def test_inserting_duplicate_scores_stable_ordering(self): assert results.matches[4].id == "4" # 0.22 assert results.matches[4].namespace == "ns1" - def test_correctly_handles_dotproduct_metric(self): - # For this index metric, the higher the score, the more similar the vectors are. - # We have to infer that we have this type of index by seeing whether scores are - # sorted in descending or ascending order. 
- aggregator = QueryResultsAggregator(top_k=3) - - desc_results1 = { - "matches": [ - {"id": "1", "score": 0.9, "values": []}, - {"id": "2", "score": 0.8, "values": []}, - {"id": "3", "score": 0.7, "values": []}, - {"id": "4", "score": 0.6, "values": []}, - {"id": "5", "score": 0.5, "values": []}, - ], - "usage": {"readUnits": 5}, - "namespace": "ns1", - } - aggregator.add_results(desc_results1) - - results2 = { - "matches": [ - {"id": "7", "score": 0.77, "values": []}, - {"id": "8", "score": 0.88, "values": []}, - {"id": "9", "score": 0.99, "values": []}, - {"id": "10", "score": 0.1010, "values": []}, - {"id": "11", "score": 0.1111, "values": []}, - ], - "usage": {"readUnits": 7}, - "namespace": "ns2", - } - aggregator.add_results(results2) - - results = aggregator.get_results() - assert results.usage.read_units == 12 - assert len(results.matches) == 3 - assert results.matches[0].id == "9" # 0.99 - assert results.matches[1].id == "1" # 0.9 - assert results.matches[2].id == "8" # 0.88 - def test_still_correct_with_early_return(self): - aggregator = QueryResultsAggregator(top_k=5) + aggregator = QueryResultsAggregator(top_k=5, metric="euclidean") results1 = { "matches": [ @@ -172,8 +132,8 @@ def test_still_correct_with_early_return(self): assert results.matches[3].id == "2" assert results.matches[4].id == "3" - def test_still_correct_with_early_return_generated_nont_dotproduct(self): - aggregator = QueryResultsAggregator(top_k=1000) + def test_fewer_results_than_topk(self): + aggregator = QueryResultsAggregator(top_k=1000, metric="euclidean") matches1 = [ {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) ] @@ -223,8 +183,93 @@ def test_still_correct_with_early_return_generated_nont_dotproduct(self): assert results.matches[3].score == merged_matches[3]["score"] assert results.matches[4].score == merged_matches[4]["score"] - def test_still_correct_with_early_return_generated_dotproduct(self): - aggregator = QueryResultsAggregator(top_k=1000) + +class TestQueryResultsAggregator_CosineOrDotproductIndex: + def test_inserting_duplicate_scores_stable_ordering(self): + aggregator = QueryResultsAggregator(top_k=6, metric="cosine") + + results1 = { + "matches": [ + {"id": "1", "score": 0.94, "values": []}, + {"id": "3", "score": 0.94, "values": []}, + {"id": "2", "score": 0.94, "values": []}, + {"id": "4", "score": 0.88, "values": []}, + {"id": "5", "score": 0.88, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.93, "values": []}, + {"id": "7", "score": 0.93, "values": []}, + {"id": "8", "score": 0.85, "values": []}, + {"id": "9", "score": 0.85, "values": []}, + {"id": "10", "score": 0.80, "values": []}, + ], + "usage": {"readUnits": 8}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 13 + assert len(results.matches) == 6 + assert results.matches[0].id == "1" + assert results.matches[0].namespace == "ns1" + assert results.matches[1].id == "3" + assert results.matches[1].namespace == "ns1" + assert results.matches[2].id == "2" + assert results.matches[2].namespace == "ns1" + assert results.matches[3].id == "6" + assert results.matches[3].namespace == "ns2" + assert results.matches[4].id == "7" + assert results.matches[4].namespace == "ns2" + assert results.matches[5].id == "4" + assert results.matches[5].namespace == "ns1" + + @pytest.mark.parametrize("index_metric", 
["cosine", "dotproduct"]) + def test_correctly_handles_dotproduct_and_cosine_metric(self, index_metric): + # For this index metric, the higher the score, the more similar the vectors are. + aggregator = QueryResultsAggregator(top_k=3, metric=index_metric) + + desc_results1 = { + "matches": [ + {"id": "1", "score": 0.9, "values": []}, + {"id": "2", "score": 0.8, "values": []}, + {"id": "3", "score": 0.7, "values": []}, + {"id": "4", "score": 0.6, "values": []}, + {"id": "5", "score": 0.5, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(desc_results1) + + results2 = { + "matches": [ + {"id": "7", "score": 0.77, "values": []}, + {"id": "8", "score": 0.88, "values": []}, + {"id": "9", "score": 0.99, "values": []}, + {"id": "10", "score": 0.1010, "values": []}, + {"id": "11", "score": 0.1111, "values": []}, + ], + "usage": {"readUnits": 7}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 12 + assert len(results.matches) == 3 + assert results.matches[0].id == "9" # 0.99 + assert results.matches[1].id == "1" # 0.9 + assert results.matches[2].id == "8" # 0.88 + + def test_lots_of_results(self): + aggregator = QueryResultsAggregator(top_k=1000, metric="dotproduct") matches1 = [ {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) ] @@ -277,7 +322,7 @@ def test_still_correct_with_early_return_generated_dotproduct(self): class TestQueryResultsAggregatorOutputUX: def test_can_interact_with_attributes(self): - aggregator = QueryResultsAggregator(top_k=2) + aggregator = QueryResultsAggregator(top_k=2, metric="euclidean") results1 = { "matches": [ { @@ -306,7 +351,7 @@ def test_can_interact_with_attributes(self): assert results.matches[0].values == [0.31, 0.32, 0.33, 0.34, 0.35, 0.36] def test_can_interact_like_dict(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [ { @@ -345,13 +390,13 @@ def test_can_interact_like_dict(self): assert results["matches"][0]["score"] == 0.3 def test_can_print_empty_results_without_error(self, capsys): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results = aggregator.get_results() print(results) capsys.readouterr() def test_can_print_results_containing_None_without_error(self, capsys): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [ {"id": "1", "score": 0.1}, @@ -369,7 +414,7 @@ def test_can_print_results_containing_None_without_error(self, capsys): capsys.readouterr() def test_can_print_results_containing_short_vectors(self, capsys): - aggregator = QueryResultsAggregator(top_k=4) + aggregator = QueryResultsAggregator(top_k=4, metric="euclidean") results1 = { "matches": [ {"id": "1", "score": 0.1, "values": [0.31]}, @@ -386,7 +431,7 @@ def test_can_print_results_containing_short_vectors(self, capsys): capsys.readouterr() def test_can_print_complete_results_without_error(self, capsys): - aggregator = QueryResultsAggregator(top_k=2) + aggregator = QueryResultsAggregator(top_k=2, metric="euclidean") results1 = { "matches": [ { @@ -421,29 +466,45 @@ def test_can_print_complete_results_without_error(self, capsys): class TestQueryAggregatorEdgeCases: def test_topK_too_small(self): with pytest.raises(QueryResultsAggregatorInvalidTopKError): - 
QueryResultsAggregator(top_k=0, metric="euclidean") + + def test_results_never_added(self): + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 0 + assert len(results.matches) == 0 - def test_matches_too_small(self): - aggregator = QueryResultsAggregator(top_k=3) + @pytest.mark.parametrize("index_metric", ["euclidean", "cosine", "dotproduct"]) + def test_tie_breaking(self, index_metric): + aggregator = QueryResultsAggregator(top_k=3, metric=index_metric) results1 = { - "matches": [{"id": "1", "score": 0.1}], + "matches": [{"id": "5", "score": 0.9}, {"id": "2", "score": 0.9}], "usage": {"readUnits": 5}, "namespace": "ns1", } - with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): - aggregator.add_results(results1) - def test_empty_results(self): - aggregator = QueryResultsAggregator(top_k=3) + results2 = { + "matches": [{"id": "3", "score": 0.9}, {"id": "4", "score": 0.9}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results1) + aggregator.add_results(results2) results = aggregator.get_results() - assert results is not None - assert results.usage.read_units == 0 - assert len(results.matches) == 0 + assert results.usage.read_units == 10 + assert len(results.matches) == 3 + + # Maintains the order in which results were added when resolving ties + assert results.matches[0].id == "5" + assert results.matches[0].namespace == "ns1" + assert results.matches[1].id == "2" + assert results.matches[1].namespace == "ns1" + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns2" def test_empty_results_with_usage(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) @@ -455,7 +516,7 @@ def test_empty_results_with_usage(self): assert len(results.matches) == 0 def test_exactly_one_result(self): - aggregator = QueryResultsAggregator(top_k=3) + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], "usage": {"readUnits": 5}, @@ -482,56 +543,32 @@ def test_exactly_one_result(self): assert results.matches[2].namespace == "ns2" assert results.matches[2].score == 0.2 - def test_two_result_sets_with_single_result_errors(self): - with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): - aggregator = QueryResultsAggregator(top_k=3) - results1 = { - "matches": [{"id": "1", "score": 0.1}], - "usage": {"readUnits": 5}, - "namespace": "ns1", - } - aggregator.add_results(results1) - results2 = { - "matches": [{"id": "2", "score": 0.01}], - "usage": {"readUnits": 5}, - "namespace": "ns2", - } - aggregator.add_results(results2) - - def test_single_result_after_index_type_known_no_error(self): - aggregator = QueryResultsAggregator(top_k=3) - - results3 = { - "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], - "usage": {"readUnits": 5}, - "namespace": "ns3", - } - aggregator.add_results(results3) - + def test_two_result_sets_with_single_result(self): + aggregator = QueryResultsAggregator(top_k=3, metric="euclidean") results1 = { "matches": [{"id": "1", "score": 
0.1}], "usage": {"readUnits": 5}, "namespace": "ns1", } aggregator.add_results(results1) - results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"} + results2 = { + "matches": [{"id": "2", "score": 0.01}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } aggregator.add_results(results2) - results = aggregator.get_results() - assert results.usage.read_units == 15 - assert len(results.matches) == 3 + assert results.usage.read_units == 10 + assert len(results.matches) == 2 assert results.matches[0].id == "2" - assert results.matches[0].namespace == "ns3" + assert results.matches[0].namespace == "ns2" assert results.matches[0].score == 0.01 assert results.matches[1].id == "1" assert results.matches[1].namespace == "ns1" assert results.matches[1].score == 0.1 - assert results.matches[2].id == "3" - assert results.matches[2].namespace == "ns3" - assert results.matches[2].score == 0.2 def test_all_empty_results(self): - aggregator = QueryResultsAggregator(top_k=10) + aggregator = QueryResultsAggregator(top_k=10, metric="euclidean") aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) @@ -543,7 +580,7 @@ def test_all_empty_results(self): assert len(results.matches) == 0 def test_some_empty_results(self): - aggregator = QueryResultsAggregator(top_k=10) + aggregator = QueryResultsAggregator(top_k=10, metric="euclidean") results2 = { "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], "usage": {"readUnits": 5},
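A minimal usage sketch of the updated API, assuming a hypothetical index named "example-index" and a placeholder API key (the kwargs follow the new signature introduced above):

    from pinecone import Pinecone

    pc = Pinecone(api_key="YOUR_API_KEY")
    index = pc.Index("example-index")  # hypothetical index name

    # metric is now a required kwarg and must match the metric the index was
    # created with, so the aggregator knows whether larger or smaller scores
    # are better when merging matches across namespaces.
    combined_results = index.query_namespaces(
        vector=[0.1, 0.2],                 # must match the index dimension
        namespaces=["ns1", "ns2", "ns3"],
        metric="cosine",                   # or "euclidean" / "dotproduct"
        top_k=5,
        include_metadata=True,
    )
    for match in combined_results.matches:
        print(match.id, match.namespace, match.score)

Note that passing a metric that does not match the index configuration will not raise an error, but the merged ranking across namespaces may be wrong, since the aggregator would sort scores in the wrong direction.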