From e668c89985c468a5ac006ffb40cb3d0b09ec4584 Mon Sep 17 00:00:00 2001 From: Jennifer Hamon Date: Wed, 13 Nov 2024 06:15:31 -0500 Subject: [PATCH] Add `query_namespaces` (#409) ## Problem Sometimes people would like to run a query across multiple namespaces ## Solution Run a query for each namespace in parallel, then merge the results using a heap ```python from pinecone import Pinecone import random pc = Pinecone(api_key='api-key') index = pc.Index( host="https://indexhost/", pool_threads=10 ) query_vec = [random.random()] * dimension combined_results = index.query_namespaces( vector=query_vec, namespaces=["ns1", "ns2", "ns3", "ns4"], include_values=False, include_metadata=True, filter={"publication_date": {"$eq":"Last3Months"}}, top_k=100 ) ``` ## TODO A grpc implementation of this will follow in a separate PR. I have WIP on it, but some mypy type issues were causing me headaches and I'd rather land this stuff first. ## Type of Change - [x] New feature (non-breaking change which adds functionality) ## Test Plan Added integration tests --- pinecone/core/openapi/shared/api_client.py | 63 +- pinecone/data/index.py | 84 ++- pinecone/data/query_results_aggregator.py | 193 ++++++ pinecone/grpc/index_grpc.py | 14 +- tests/integration/data/conftest.py | 5 +- .../integration/data/test_query_namespaces.py | 226 +++++++ tests/unit/test_query_results_aggregator.py | 561 ++++++++++++++++++ 7 files changed, 1120 insertions(+), 26 deletions(-) create mode 100644 pinecone/data/query_results_aggregator.py create mode 100644 tests/integration/data/test_query_namespaces.py create mode 100644 tests/unit/test_query_results_aggregator.py diff --git a/pinecone/core/openapi/shared/api_client.py b/pinecone/core/openapi/shared/api_client.py index dda97ec5..7ec644c5 100644 --- a/pinecone/core/openapi/shared/api_client.py +++ b/pinecone/core/openapi/shared/api_client.py @@ -8,6 +8,24 @@ import typing from urllib.parse import quote from urllib3.fields import RequestField +import time +import random + +def retry_api_call( + func, args=(), kwargs={}, retries=3, backoff=1, jitter=0.5 +): + attempts = 0 + while attempts < retries: + try: + return func(*args, **kwargs) # Attempt to call __call_api + except Exception as e: + attempts += 1 + if attempts >= retries: + print(f"API call failed after {attempts} attempts: {e}") + raise # Re-raise exception if retries are exhausted + sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter) + # print(f"Retrying ({attempts}/{retries}) in {sleep_time:.2f} seconds after error: {e}") + time.sleep(sleep_time) from pinecone.core.openapi.shared import rest @@ -397,25 +415,32 @@ def call_api( ) return self.pool.apply_async( - self.__call_api, - ( - resource_path, - method, - path_params, - query_params, - header_params, - body, - post_params, - files, - response_type, - auth_settings, - _return_http_data_only, - collection_formats, - _preload_content, - _request_timeout, - _host, - _check_type, - ), + retry_api_call, + args=( + self.__call_api, # Pass the API call function as the first argument + ( + resource_path, + method, + path_params, + query_params, + header_params, + body, + post_params, + files, + response_type, + auth_settings, + _return_http_data_only, + collection_formats, + _preload_content, + _request_timeout, + _host, + _check_type, + ), + {}, # empty kwargs dictionary + 3, # retries + 1, # backoff time + 0.5 # jitter + ) ) def request( diff --git a/pinecone/data/index.py b/pinecone/data/index.py index 3f63d046..f2c4c9f9 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -33,6 +33,8 @@ ) from .features.bulk_import import ImportFeatureMixin from .vector_factory import VectorFactory +from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults +from multiprocessing.pool import ApplyResult from pinecone_plugin_interface import load_and_install as install_plugins @@ -387,7 +389,7 @@ def query( Union[SparseValues, Dict[str, Union[List[float], List[int]]]] ] = None, **kwargs, - ) -> QueryResponse: + ) -> Union[QueryResponse, ApplyResult]: """ The Query operation searches a namespace, using a query vector. It retrieves the ids of the most similar items in a namespace, along with their similarity scores. @@ -429,6 +431,39 @@ def query( and namespace name. """ + response = self._query( + *args, + top_k=top_k, + vector=vector, + id=id, + namespace=namespace, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + **kwargs, + ) + + if kwargs.get("async_req", False): + return response + else: + return parse_query_response(response) + + def _query( + self, + *args, + top_k: int, + vector: Optional[List[float]] = None, + id: Optional[str] = None, + namespace: Optional[str] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[ + Union[SparseValues, Dict[str, Union[List[float], List[int]]]] + ] = None, + **kwargs, + ) -> QueryResponse: if len(args) > 0: raise ValueError( "The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')" @@ -461,7 +496,52 @@ def query( ), **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, ) - return parse_query_response(response) + return response + + @validate_and_convert_errors + def query_namespaces( + self, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[ + Union[SparseValues, Dict[str, Union[List[float], List[int]]]] + ] = None, + **kwargs, + ) -> QueryNamespacesResults: + if namespaces is None or len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + + target_namespaces = set(namespaces) # dedup namespaces + async_results = [ + self.query( + vector=vector, + namespace=ns, + top_k=overall_topk, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + async_req=True, + **kwargs, + ) + for ns in target_namespaces + ] + + for result in async_results: + response = result.get() + aggregator.add_results(response) + + final_results = aggregator.get_results() + return final_results @validate_and_convert_errors def update( diff --git a/pinecone/data/query_results_aggregator.py b/pinecone/data/query_results_aggregator.py new file mode 100644 index 00000000..98ca77a2 --- /dev/null +++ b/pinecone/data/query_results_aggregator.py @@ -0,0 +1,193 @@ +from typing import List, Tuple, Optional, Any, Dict +import json +import heapq +from pinecone.core.openapi.data.models import Usage +from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse + +from dataclasses import dataclass, asdict + + +@dataclass +class ScoredVectorWithNamespace: + namespace: str + score: float + id: str + values: List[float] + sparse_values: dict + metadata: dict + + def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]): + json_vector = aggregate_results_heap_tuple[2] + self.namespace = aggregate_results_heap_tuple[3] + self.id = json_vector.get("id") # type: ignore + self.score = json_vector.get("score") # type: ignore + self.values = json_vector.get("values") # type: ignore + self.sparse_values = json_vector.get("sparse_values", None) # type: ignore + self.metadata = json_vector.get("metadata", None) # type: ignore + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace") + + def get(self, key, default=None): + return getattr(self, key, default) + + def __repr__(self): + return json.dumps(self._truncate(asdict(self)), indent=4) + + def __json__(self): + return self._truncate(asdict(self)) + + def _truncate(self, obj, max_items=2): + """ + Recursively traverse and truncate lists that exceed max_items length. + Only display the "... X more" message if at least 2 elements are hidden. + """ + if obj is None: + return None # Skip None values + elif isinstance(obj, list): + filtered_list = [self._truncate(i, max_items) for i in obj if i is not None] + if len(filtered_list) > max_items: + # Show the truncation message only if more than 1 item is hidden + remaining_items = len(filtered_list) - max_items + if remaining_items > 1: + return filtered_list[:max_items] + [f"... {remaining_items} more"] + else: + # If only 1 item remains, show it + return filtered_list + return filtered_list + elif isinstance(obj, dict): + # Recursively process dictionaries, omitting None values + return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None} + return obj + + +@dataclass +class QueryNamespacesResults: + usage: Usage + matches: List[ScoredVectorWithNamespace] + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(f"'{key}' not found in QueryNamespacesResults") + + def get(self, key, default=None): + return getattr(self, key, default) + + def __repr__(self): + return json.dumps( + { + "usage": self.usage.to_dict(), + "matches": [match.__json__() for match in self.matches], + }, + indent=4, + ) + + +class QueryResultsAggregregatorNotEnoughResultsError(Exception): + def __init__(self): + super().__init__( + "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." + ) + + +class QueryResultsAggregatorInvalidTopKError(Exception): + def __init__(self, top_k: int): + super().__init__( + f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2." + ) + + +class QueryResultsAggregator: + def __init__(self, top_k: int): + if top_k < 2: + raise QueryResultsAggregatorInvalidTopKError(top_k) + self.top_k = top_k + self.usage_read_units = 0 + self.heap: List[Tuple[float, int, object, str]] = [] + self.insertion_counter = 0 + self.is_dotproduct = None + self.read = False + self.final_results: Optional[QueryNamespacesResults] = None + + def _is_dotproduct_index(self, matches): + # The interpretation of the score depends on the similar metric used. + # Unlike other index types, in indexes configured for dotproduct, + # a higher score is better. We have to infer this is the case by inspecting + # the order of the scores in the results. + for i in range(1, len(matches)): + if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase + return False + return True + + def _dotproduct_heap_item(self, match, ns): + return (match.get("score"), -self.insertion_counter, match, ns) + + def _non_dotproduct_heap_item(self, match, ns): + return (-match.get("score"), -self.insertion_counter, match, ns) + + def _process_matches(self, matches, ns, heap_item_fn): + for match in matches: + self.insertion_counter += 1 + if len(self.heap) < self.top_k: + heapq.heappush(self.heap, heap_item_fn(match, ns)) + else: + # Assume we have dotproduct scores sorted in descending order + if self.is_dotproduct and match["score"] < self.heap[0][0]: + # No further matches can improve the top-K heap + break + elif not self.is_dotproduct and match["score"] > -self.heap[0][0]: + # No further matches can improve the top-K heap + break + heapq.heappushpop(self.heap, heap_item_fn(match, ns)) + + def add_results(self, results: Dict[str, Any]): + if self.read: + # This is mainly just to sanity check in test cases which get quite confusing + # if you read results twice due to the heap being emptied when constructing + # the ordered results. + raise ValueError("Results have already been read. Cannot add more results.") + + matches = results.get("matches", []) + ns: str = results.get("namespace", "") + if isinstance(results, OpenAPIQueryResponse): + self.usage_read_units += results.usage.read_units + else: + self.usage_read_units += results.get("usage", {}).get("readUnits", 0) + + if len(matches) == 0: + return + + if self.is_dotproduct is None: + if len(matches) == 1: + # This condition should match the second time we add results containing + # only one match. We need at least two matches in a single response in order + # to infer the similarity metric + raise QueryResultsAggregregatorNotEnoughResultsError() + self.is_dotproduct = self._is_dotproduct_index(matches) + + if self.is_dotproduct: + self._process_matches(matches, ns, self._dotproduct_heap_item) + else: + self._process_matches(matches, ns, self._non_dotproduct_heap_item) + + def get_results(self) -> QueryNamespacesResults: + if self.read: + if self.final_results is not None: + return self.final_results + else: + # I don't think this branch can ever actually be reached, but the type checker disagrees + raise ValueError("Results have already been read. Cannot get results again.") + self.read = True + + self.final_results = QueryNamespacesResults( + usage=Usage(read_units=self.usage_read_units), + matches=[ + ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap)) + ][::-1], + ) + return self.final_results diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 317a0fc5..eba611b7 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -326,8 +326,9 @@ def query( include_values: Optional[bool] = None, include_metadata: Optional[bool] = None, sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + async_req: Optional[bool] = False, **kwargs, - ) -> QueryResponse: + ) -> Union[QueryResponse, PineconeGrpcFuture]: """ The Query operation searches a namespace, using a query vector. It retrieves the ids of the most similar items in a namespace, along with their similarity scores. @@ -392,9 +393,14 @@ def query( request = QueryRequest(**args_dict) timeout = kwargs.pop("timeout", None) - response = self.runner.run(self.stub.Query, request, timeout=timeout) - json_response = json_format.MessageToDict(response) - return parse_query_response(json_response, _check_type=False) + + if async_req: + future = self.runner.run(self.stub.Query.future, request, timeout=timeout) + return PineconeGrpcFuture(future) + else: + response = self.runner.run(self.stub.Query, request, timeout=timeout) + json_response = json_format.MessageToDict(response) + return parse_query_response(json_response, _check_type=False) def update( self, diff --git a/tests/integration/data/conftest.py b/tests/integration/data/conftest.py index 828a6d4f..00a43747 100644 --- a/tests/integration/data/conftest.py +++ b/tests/integration/data/conftest.py @@ -52,7 +52,10 @@ def metric(): @pytest.fixture(scope="session") def spec(): - return json.loads(get_environment_var("SPEC")) + spec_json = get_environment_var( + "SPEC", '{"serverless": {"cloud": "aws", "region": "us-east-1" }}' + ) + return json.loads(spec_json) @pytest.fixture(scope="session") diff --git a/tests/integration/data/test_query_namespaces.py b/tests/integration/data/test_query_namespaces.py new file mode 100644 index 00000000..e52c58b0 --- /dev/null +++ b/tests/integration/data/test_query_namespaces.py @@ -0,0 +1,226 @@ +import pytest +import os +from ..helpers import random_string, poll_stats_for_namespace +from pinecone.data.query_results_aggregator import ( + QueryResultsAggregatorInvalidTopKError, + QueryResultsAggregregatorNotEnoughResultsError, +) + +from pinecone import Vector + + +@pytest.mark.skipif( + os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest" +) +class TestQueryNamespacesRest: + def test_query_namespaces(self, idx): + ns_prefix = random_string(5) + ns1 = f"{ns_prefix}-ns1" + ns2 = f"{ns_prefix}-ns2" + ns3 = f"{ns_prefix}-ns3" + + idx.upsert( + vectors=[ + Vector(id="id1", values=[0.1, 0.2], metadata={"genre": "drama", "key": 1}), + Vector(id="id2", values=[0.2, 0.3], metadata={"genre": "drama", "key": 2}), + Vector(id="id3", values=[0.4, 0.5], metadata={"genre": "action", "key": 3}), + Vector(id="id4", values=[0.6, 0.7], metadata={"genre": "action", "key": 4}), + ], + namespace=ns1, + ) + idx.upsert( + vectors=[ + Vector(id="id5", values=[0.21, 0.22], metadata={"genre": "drama", "key": 1}), + Vector(id="id6", values=[0.22, 0.23], metadata={"genre": "drama", "key": 2}), + Vector(id="id7", values=[0.24, 0.25], metadata={"genre": "action", "key": 3}), + Vector(id="id8", values=[0.26, 0.27], metadata={"genre": "action", "key": 4}), + ], + namespace=ns2, + ) + idx.upsert( + vectors=[ + Vector(id="id9", values=[0.31, 0.32], metadata={"genre": "drama", "key": 1}), + Vector(id="id10", values=[0.32, 0.33], metadata={"genre": "drama", "key": 2}), + Vector(id="id11", values=[0.34, 0.35], metadata={"genre": "action", "key": 3}), + Vector(id="id12", values=[0.36, 0.37], metadata={"genre": "action", "key": 4}), + ], + namespace=ns3, + ) + + poll_stats_for_namespace(idx, namespace=ns1, expected_count=4) + poll_stats_for_namespace(idx, namespace=ns2, expected_count=4) + poll_stats_for_namespace(idx, namespace=ns3, expected_count=4) + + results = idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ns1, ns2, ns3], + include_values=True, + include_metadata=True, + filter={"genre": {"$eq": "drama"}}, + top_k=100, + ) + assert len(results.matches) == 6 + assert results.usage.read_units > 0 + for item in results.matches: + assert item.metadata["genre"] == "drama" + assert results.matches[0].id == "id1" + assert results.matches[0].namespace == ns1 + + # Using dot-style accessors + assert results.matches[0].metadata["genre"] == "drama" + assert results.matches[0].metadata["key"] == 1 + + # Using dictionary-style accessors + assert results.matches[0]["metadata"]["genre"] == "drama" + assert results.matches[0]["metadata"]["key"] == 1 + + # Using .get() accessors + assert results.get("matches", [])[0].get("metadata", {}).get("genre") == "drama" + assert results.matches[0].get("metadata", {}) == {"genre": "drama", "key": 1} + assert results.matches[0].get("metadata", {}).get("genre") == "drama" + + assert results.matches[1].id == "id2" + assert results.matches[1].namespace == ns1 + assert results.matches[2].id == "id5" + assert results.matches[2].namespace == ns2 + + # Non-existent namespace shouldn't cause any problem + results2 = idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ns1, ns2, ns3, f"{ns_prefix}-nonexistent"], + include_values=True, + include_metadata=True, + filter={"genre": {"$eq": "action"}}, + top_k=100, + ) + assert len(results2.matches) == 6 + assert results2.usage.read_units > 0 + for item in results2.matches: + assert item.metadata["genre"] == "action" + + # Test with empty filter, top_k greater than number of results + results3 = idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ns1, ns2, ns3], + include_values=True, + include_metadata=True, + filter={}, + top_k=100, + ) + assert len(results3.matches) == 12 + assert results3.usage.read_units > 0 + + # Test when all results are filtered out + results4 = idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ns1, ns2, ns3], + include_values=True, + include_metadata=True, + filter={"genre": {"$eq": "comedy"}}, + top_k=100, + ) + assert len(results4.matches) == 0 + assert results4.usage.read_units > 0 + + # Test with top_k less than number of results + results5 = idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ns1, ns2, ns3], + include_values=True, + include_metadata=True, + filter={}, + top_k=2, + ) + assert len(results5.matches) == 2 + + # Test when all namespaces are non-existent (same as all results filtered / empty) + results6 = idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ + f"{ns_prefix}-nonexistent1", + f"{ns_prefix}-nonexistent2", + f"{ns_prefix}-nonexistent3", + ], + include_values=True, + include_metadata=True, + filter={"genre": {"$eq": "comedy"}}, + top_k=2, + ) + assert len(results6.matches) == 0 + assert results6.usage.read_units > 0 + + def test_invalid_top_k(self, idx): + with pytest.raises(QueryResultsAggregatorInvalidTopKError) as e: + idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=["ns1", "ns2", "ns3"], + include_values=True, + include_metadata=True, + filter={}, + top_k=1, + ) + assert ( + str(e.value) + == "Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2." + ) + + def test_unmergeable_results(self, idx): + ns_prefix = random_string(5) + ns1 = f"{ns_prefix}-ns1" + ns2 = f"{ns_prefix}-ns2" + + idx.upsert( + vectors=[ + Vector(id="id1", values=[0.1, 0.2], metadata={"genre": "drama", "key": 1}), + Vector(id="id2", values=[0.2, 0.3], metadata={"genre": "drama", "key": 2}), + ], + namespace=ns1, + ) + idx.upsert( + vectors=[ + Vector(id="id5", values=[0.21, 0.22], metadata={"genre": "drama", "key": 1}), + Vector(id="id6", values=[0.22, 0.23], metadata={"genre": "drama", "key": 2}), + ], + namespace=ns2, + ) + + poll_stats_for_namespace(idx, namespace=ns1, expected_count=2) + poll_stats_for_namespace(idx, namespace=ns2, expected_count=2) + + with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError) as e: + idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[ns1, ns2], + include_values=True, + include_metadata=True, + filter={"key": {"$eq": 1}}, + top_k=2, + ) + + assert ( + str(e.value) + == "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." + ) + + def test_missing_namespaces(self, idx): + with pytest.raises(ValueError) as e: + idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=[], + include_values=True, + include_metadata=True, + filter={}, + top_k=2, + ) + assert str(e.value) == "At least one namespace must be specified" + + with pytest.raises(ValueError) as e: + idx.query_namespaces( + vector=[0.1, 0.2], + namespaces=None, + include_values=True, + include_metadata=True, + filter={}, + top_k=2, + ) + assert str(e.value) == "At least one namespace must be specified" diff --git a/tests/unit/test_query_results_aggregator.py b/tests/unit/test_query_results_aggregator.py new file mode 100644 index 00000000..c482ca15 --- /dev/null +++ b/tests/unit/test_query_results_aggregator.py @@ -0,0 +1,561 @@ +from pinecone.data.query_results_aggregator import ( + QueryResultsAggregator, + QueryResultsAggregatorInvalidTopKError, + QueryResultsAggregregatorNotEnoughResultsError, +) +import random +import pytest + + +class TestQueryResultsAggregator: + def test_keeps_running_usage_total(self): + aggregator = QueryResultsAggregator(top_k=3) + + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "3", "score": 0.12, "values": []}, + {"id": "4", "score": 0.13, "values": []}, + {"id": "5", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "7", "score": 0.101, "values": []}, + {"id": "8", "score": 0.111, "values": []}, + {"id": "9", "score": 0.12, "values": []}, + {"id": "10", "score": 0.13, "values": []}, + {"id": "11", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 7}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 12 + assert len(results.matches) == 3 + assert results.matches[0].id == "1" # 0.1 + assert results.matches[1].id == "7" # 0.101 + assert results.matches[2].id == "2" # 0.11 + + # Bracket-style accessor + assert results["usage"]["read_units"] == results.usage.read_units + assert results["matches"][0]["id"] == results.matches[0].id + + # Get-style accessor + assert results.get("matches", []) == results.matches + assert results.get("usage", {}).get("read_units") == results.usage.read_units + + def test_inserting_duplicate_scores_stable_ordering(self): + aggregator = QueryResultsAggregator(top_k=5) + + results1 = { + "matches": [ + {"id": "1", "score": 0.11, "values": []}, + {"id": "3", "score": 0.11, "values": []}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "4", "score": 0.22, "values": []}, + {"id": "5", "score": 0.22, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.11, "values": []}, + {"id": "7", "score": 0.22, "values": []}, + {"id": "8", "score": 0.22, "values": []}, + {"id": "9", "score": 0.22, "values": []}, + {"id": "10", "score": 0.22, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 5 + assert results.matches[0].id == "1" # 0.11 + assert results.matches[0].namespace == "ns1" + assert results.matches[1].id == "3" # 0.11 + assert results.matches[1].namespace == "ns1" + assert results.matches[2].id == "2" # 0.11 + assert results.matches[2].namespace == "ns1" + assert results.matches[3].id == "6" # 0.11 + assert results.matches[3].namespace == "ns2" + assert results.matches[4].id == "4" # 0.22 + assert results.matches[4].namespace == "ns1" + + def test_correctly_handles_dotproduct_metric(self): + # For this index metric, the higher the score, the more similar the vectors are. + # We have to infer that we have this type of index by seeing whether scores are + # sorted in descending or ascending order. + aggregator = QueryResultsAggregator(top_k=3) + + desc_results1 = { + "matches": [ + {"id": "1", "score": 0.9, "values": []}, + {"id": "2", "score": 0.8, "values": []}, + {"id": "3", "score": 0.7, "values": []}, + {"id": "4", "score": 0.6, "values": []}, + {"id": "5", "score": 0.5, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(desc_results1) + + results2 = { + "matches": [ + {"id": "7", "score": 0.77, "values": []}, + {"id": "8", "score": 0.88, "values": []}, + {"id": "9", "score": 0.99, "values": []}, + {"id": "10", "score": 0.1010, "values": []}, + {"id": "11", "score": 0.1111, "values": []}, + ], + "usage": {"readUnits": 7}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 12 + assert len(results.matches) == 3 + assert results.matches[0].id == "9" # 0.99 + assert results.matches[1].id == "1" # 0.9 + assert results.matches[2].id == "8" # 0.88 + + def test_still_correct_with_early_return(self): + aggregator = QueryResultsAggregator(top_k=5) + + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": []}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "3", "score": 0.12, "values": []}, + {"id": "4", "score": 0.13, "values": []}, + {"id": "5", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.10, "values": []}, + {"id": "7", "score": 0.101, "values": []}, + {"id": "8", "score": 0.12, "values": []}, + {"id": "9", "score": 0.13, "values": []}, + {"id": "10", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 5 + assert results.matches[0].id == "1" + assert results.matches[1].id == "6" + assert results.matches[2].id == "7" + assert results.matches[3].id == "2" + assert results.matches[4].id == "3" + + def test_still_correct_with_early_return_generated_nont_dotproduct(self): + aggregator = QueryResultsAggregator(top_k=1000) + matches1 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) + ] + matches1.sort(key=lambda x: x["score"], reverse=False) + + matches2 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000) + ] + matches2.sort(key=lambda x: x["score"], reverse=False) + + matches3 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000) + ] + matches3.sort(key=lambda x: x["score"], reverse=False) + + matches4 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000) + ] + matches4.sort(key=lambda x: x["score"], reverse=False) + + matches5 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000) + ] + matches5.sort(key=lambda x: x["score"], reverse=False) + + results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}} + results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}} + results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}} + results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}} + results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}} + + aggregator.add_results(results1) + aggregator.add_results(results2) + aggregator.add_results(results3) + aggregator.add_results(results4) + aggregator.add_results(results5) + + merged_matches = matches1 + matches2 + matches3 + matches4 + matches5 + merged_matches.sort(key=lambda x: x["score"], reverse=False) + + results = aggregator.get_results() + assert results.usage.read_units == 25 + assert len(results.matches) == 1000 + assert results.matches[0].score == merged_matches[0]["score"] + assert results.matches[1].score == merged_matches[1]["score"] + assert results.matches[2].score == merged_matches[2]["score"] + assert results.matches[3].score == merged_matches[3]["score"] + assert results.matches[4].score == merged_matches[4]["score"] + + def test_still_correct_with_early_return_generated_dotproduct(self): + aggregator = QueryResultsAggregator(top_k=1000) + matches1 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) + ] + matches1.sort(key=lambda x: x["score"], reverse=True) + + matches2 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000) + ] + matches2.sort(key=lambda x: x["score"], reverse=True) + + matches3 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000) + ] + matches3.sort(key=lambda x: x["score"], reverse=True) + + matches4 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000) + ] + matches4.sort(key=lambda x: x["score"], reverse=True) + + matches5 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000) + ] + matches5.sort(key=lambda x: x["score"], reverse=True) + + results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}} + results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}} + results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}} + results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}} + results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}} + + aggregator.add_results(results1) + aggregator.add_results(results2) + aggregator.add_results(results3) + aggregator.add_results(results4) + aggregator.add_results(results5) + + merged_matches = matches1 + matches2 + matches3 + matches4 + matches5 + merged_matches.sort(key=lambda x: x["score"], reverse=True) + + results = aggregator.get_results() + assert results.usage.read_units == 25 + assert len(results.matches) == 1000 + assert results.matches[0].score == merged_matches[0]["score"] + assert results.matches[1].score == merged_matches[1]["score"] + assert results.matches[2].score == merged_matches[2]["score"] + assert results.matches[3].score == merged_matches[3]["score"] + assert results.matches[4].score == merged_matches[4]["score"] + + +class TestQueryResultsAggregatorOutputUX: + def test_can_interact_with_attributes(self): + aggregator = QueryResultsAggregator(top_k=2) + results1 = { + "matches": [ + { + "id": "1", + "score": 0.3, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + {"id": "2", "score": 0.4}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + assert results.usage.read_units == 5 + assert results.matches[0].id == "1" + assert results.matches[0].namespace == "ns1" + assert results.matches[0].score == 0.3 + assert results.matches[0].values == [0.31, 0.32, 0.33, 0.34, 0.35, 0.36] + + def test_can_interact_like_dict(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [ + { + "id": "1", + "score": 0.3, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + { + "id": "2", + "score": 0.4, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + assert results["usage"]["read_units"] == 5 + assert results["matches"][0]["id"] == "1" + assert results["matches"][0]["namespace"] == "ns1" + assert results["matches"][0]["score"] == 0.3 + + def test_can_print_empty_results_without_error(self, capsys): + aggregator = QueryResultsAggregator(top_k=3) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + def test_can_print_results_containing_None_without_error(self, capsys): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [ + {"id": "1", "score": 0.1}, + {"id": "2", "score": 0.11}, + {"id": "3", "score": 0.12}, + {"id": "4", "score": 0.13}, + {"id": "5", "score": 0.14}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + def test_can_print_results_containing_short_vectors(self, capsys): + aggregator = QueryResultsAggregator(top_k=4) + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": [0.31]}, + {"id": "2", "score": 0.11, "values": [0.31, 0.32]}, + {"id": "3", "score": 0.12, "values": [0.31, 0.32, 0.33]}, + {"id": "3", "score": 0.12, "values": [0.31, 0.32, 0.33, 0.34]}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + def test_can_print_complete_results_without_error(self, capsys): + aggregator = QueryResultsAggregator(top_k=2) + results1 = { + "matches": [ + { + "id": "1", + "score": 0.3, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + { + "id": "2", + "score": 0.4, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": {"boolean": True, "nullish": None}, + }, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + +class TestQueryAggregatorEdgeCases: + def test_topK_too_small(self): + with pytest.raises(QueryResultsAggregatorInvalidTopKError): + QueryResultsAggregator(top_k=0) + with pytest.raises(QueryResultsAggregatorInvalidTopKError): + QueryResultsAggregator(top_k=1) + + def test_matches_too_small(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): + aggregator.add_results(results1) + + def test_empty_results(self): + aggregator = QueryResultsAggregator(top_k=3) + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 0 + assert len(results.matches) == 0 + + def test_empty_results_with_usage(self): + aggregator = QueryResultsAggregator(top_k=3) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 15 + assert len(results.matches) == 0 + + def test_exactly_one_result(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results1) + + results2 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results2) + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 3 + assert results.matches[0].id == "2" + assert results.matches[0].namespace == "ns2" + assert results.matches[0].score == 0.01 + assert results.matches[1].id == "1" + assert results.matches[1].namespace == "ns1" + assert results.matches[1].score == 0.1 + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns2" + assert results.matches[2].score == 0.2 + + def test_two_result_sets_with_single_result_errors(self): + with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results2 = { + "matches": [{"id": "2", "score": 0.01}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + def test_single_result_after_index_type_known_no_error(self): + aggregator = QueryResultsAggregator(top_k=3) + + results3 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns3", + } + aggregator.add_results(results3) + + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"} + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 15 + assert len(results.matches) == 3 + assert results.matches[0].id == "2" + assert results.matches[0].namespace == "ns3" + assert results.matches[0].score == 0.01 + assert results.matches[1].id == "1" + assert results.matches[1].namespace == "ns1" + assert results.matches[1].score == 0.1 + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns3" + assert results.matches[2].score == 0.2 + + def test_all_empty_results(self): + aggregator = QueryResultsAggregator(top_k=10) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + + assert results.usage.read_units == 15 + assert len(results.matches) == 0 + + def test_some_empty_results(self): + aggregator = QueryResultsAggregator(top_k=10) + results2 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns0", + } + aggregator.add_results(results2) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + + assert results.usage.read_units == 20 + assert len(results.matches) == 2