From 77a731b2584cde39c51d437d70f7286706cdc32c Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Tue, 22 Oct 2024 11:23:07 -0400 Subject: [PATCH] Add early return logic in results merge --- pinecone/grpc/query_results_aggregator.py | 11 +- .../test_query_results_aggregator.py | 141 ++++++++++++++++++ 2 files changed, 150 insertions(+), 2 deletions(-) diff --git a/pinecone/grpc/query_results_aggregator.py b/pinecone/grpc/query_results_aggregator.py index 448fae48..b01a1200 100644 --- a/pinecone/grpc/query_results_aggregator.py +++ b/pinecone/grpc/query_results_aggregator.py @@ -134,6 +134,13 @@ def _process_matches(self, matches, ns, heap_item_fn): if len(self.heap) < self.top_k: heapq.heappush(self.heap, heap_item_fn(match, ns)) else: + # Assume we have dotproduct scores sorted in descending order + if self.is_dotproduct and match["score"] < self.heap[0][0]: + # No further matches can improve the top-K heap + break + elif not self.is_dotproduct and match["score"] > -self.heap[0][0]: + # No further matches can improve the top-K heap + break heapq.heappushpop(self.heap, heap_item_fn(match, ns)) def add_results(self, results: Dict[str, Any]): @@ -156,9 +163,9 @@ def add_results(self, results: Dict[str, Any]): self.is_dotproduct = self._is_dotproduct_index(matches) if self.is_dotproduct: - self._process_matches(matches, ns, self._dotproduct_heap_item) + self._process_matches2(matches, ns, self._dotproduct_heap_item) else: - self._process_matches(matches, ns, self._non_dotproduct_heap_item) + self._process_matches2(matches, ns, self._non_dotproduct_heap_item) def get_results(self) -> QueryNamespacesResults: if self.read: diff --git a/tests/unit_grpc/test_query_results_aggregator.py b/tests/unit_grpc/test_query_results_aggregator.py index d4213ae6..e78d255b 100644 --- a/tests/unit_grpc/test_query_results_aggregator.py +++ b/tests/unit_grpc/test_query_results_aggregator.py @@ -3,6 +3,7 @@ QueryResultsAggregatorInvalidTopKError, QueryResultsAggregregatorNotEnoughResultsError, ) +import random import pytest @@ -125,6 +126,146 @@ def test_correctly_handles_dotproduct_metric(self): assert results.matches[1].id == "1" # 0.9 assert results.matches[2].id == "8" # 0.88 + def test_still_correct_with_early_return(self): + aggregator = QueryResultsAggregator(top_k=5) + + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": []}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "3", "score": 0.12, "values": []}, + {"id": "4", "score": 0.13, "values": []}, + {"id": "5", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.10, "values": []}, + {"id": "7", "score": 0.101, "values": []}, + {"id": "8", "score": 0.12, "values": []}, + {"id": "9", "score": 0.13, "values": []}, + {"id": "10", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 5 + assert results.matches[0].id == "1" + assert results.matches[1].id == "6" + assert results.matches[2].id == "7" + assert results.matches[3].id == "2" + assert results.matches[4].id == "3" + + def test_still_correct_with_early_return_generated_nont_dotproduct(self): + aggregator = QueryResultsAggregator(top_k=1000) + matches1 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) + ] + matches1.sort(key=lambda x: x["score"], reverse=False) + + matches2 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000) + ] + matches2.sort(key=lambda x: x["score"], reverse=False) + + matches3 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000) + ] + matches3.sort(key=lambda x: x["score"], reverse=False) + + matches4 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000) + ] + matches4.sort(key=lambda x: x["score"], reverse=False) + + matches5 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000) + ] + matches5.sort(key=lambda x: x["score"], reverse=False) + + results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}} + results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}} + results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}} + results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}} + results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}} + + aggregator.add_results(results1) + aggregator.add_results(results2) + aggregator.add_results(results3) + aggregator.add_results(results4) + aggregator.add_results(results5) + + merged_matches = matches1 + matches2 + matches3 + matches4 + matches5 + merged_matches.sort(key=lambda x: x["score"], reverse=False) + + results = aggregator.get_results() + assert results.usage.read_units == 25 + assert len(results.matches) == 1000 + assert results.matches[0].score == merged_matches[0]["score"] + assert results.matches[1].score == merged_matches[1]["score"] + assert results.matches[2].score == merged_matches[2]["score"] + assert results.matches[3].score == merged_matches[3]["score"] + assert results.matches[4].score == merged_matches[4]["score"] + + def test_still_correct_with_early_return_generated_dotproduct(self): + aggregator = QueryResultsAggregator(top_k=1000) + matches1 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) + ] + matches1.sort(key=lambda x: x["score"], reverse=True) + + matches2 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000) + ] + matches2.sort(key=lambda x: x["score"], reverse=True) + + matches3 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000) + ] + matches3.sort(key=lambda x: x["score"], reverse=True) + + matches4 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000) + ] + matches4.sort(key=lambda x: x["score"], reverse=True) + + matches5 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000) + ] + matches5.sort(key=lambda x: x["score"], reverse=True) + + results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}} + results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}} + results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}} + results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}} + results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}} + + aggregator.add_results(results1) + aggregator.add_results(results2) + aggregator.add_results(results3) + aggregator.add_results(results4) + aggregator.add_results(results5) + + merged_matches = matches1 + matches2 + matches3 + matches4 + matches5 + merged_matches.sort(key=lambda x: x["score"], reverse=True) + + results = aggregator.get_results() + assert results.usage.read_units == 25 + assert len(results.matches) == 1000 + assert results.matches[0].score == merged_matches[0]["score"] + assert results.matches[1].score == merged_matches[1]["score"] + assert results.matches[2].score == merged_matches[2]["score"] + assert results.matches[3].score == merged_matches[3]["score"] + assert results.matches[4].score == merged_matches[4]["score"] + class TestQueryResultsAggregatorOutputUX: def test_can_interact_with_attributes(self):