Skip to content

Commit

Permalink
Add early return logic in results merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Oct 22, 2024
1 parent ae244a2 commit 77a731b
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 2 deletions.
11 changes: 9 additions & 2 deletions pinecone/grpc/query_results_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def _process_matches(self, matches, ns, heap_item_fn):
if len(self.heap) < self.top_k:
heapq.heappush(self.heap, heap_item_fn(match, ns))
else:
# Assume we have dotproduct scores sorted in descending order
if self.is_dotproduct and match["score"] < self.heap[0][0]:
# No further matches can improve the top-K heap
break
elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
# No further matches can improve the top-K heap
break
heapq.heappushpop(self.heap, heap_item_fn(match, ns))

def add_results(self, results: Dict[str, Any]):
Expand All @@ -156,9 +163,9 @@ def add_results(self, results: Dict[str, Any]):
self.is_dotproduct = self._is_dotproduct_index(matches)

if self.is_dotproduct:
self._process_matches(matches, ns, self._dotproduct_heap_item)
self._process_matches2(matches, ns, self._dotproduct_heap_item)
else:
self._process_matches(matches, ns, self._non_dotproduct_heap_item)
self._process_matches2(matches, ns, self._non_dotproduct_heap_item)

def get_results(self) -> QueryNamespacesResults:
if self.read:
Expand Down
141 changes: 141 additions & 0 deletions tests/unit_grpc/test_query_results_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
QueryResultsAggregatorInvalidTopKError,
QueryResultsAggregregatorNotEnoughResultsError,
)
import random
import pytest


Expand Down Expand Up @@ -125,6 +126,146 @@ def test_correctly_handles_dotproduct_metric(self):
assert results.matches[1].id == "1" # 0.9
assert results.matches[2].id == "8" # 0.88

def test_still_correct_with_early_return(self):
aggregator = QueryResultsAggregator(top_k=5)

results1 = {
"matches": [
{"id": "1", "score": 0.1, "values": []},
{"id": "2", "score": 0.11, "values": []},
{"id": "3", "score": 0.12, "values": []},
{"id": "4", "score": 0.13, "values": []},
{"id": "5", "score": 0.14, "values": []},
],
"usage": {"readUnits": 5},
"namespace": "ns1",
}
aggregator.add_results(results1)

results2 = {
"matches": [
{"id": "6", "score": 0.10, "values": []},
{"id": "7", "score": 0.101, "values": []},
{"id": "8", "score": 0.12, "values": []},
{"id": "9", "score": 0.13, "values": []},
{"id": "10", "score": 0.14, "values": []},
],
"usage": {"readUnits": 5},
"namespace": "ns2",
}
aggregator.add_results(results2)

results = aggregator.get_results()
assert results.usage.read_units == 10
assert len(results.matches) == 5
assert results.matches[0].id == "1"
assert results.matches[1].id == "6"
assert results.matches[2].id == "7"
assert results.matches[3].id == "2"
assert results.matches[4].id == "3"

def test_still_correct_with_early_return_generated_nont_dotproduct(self):
aggregator = QueryResultsAggregator(top_k=1000)
matches1 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000)
]
matches1.sort(key=lambda x: x["score"], reverse=False)

matches2 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000)
]
matches2.sort(key=lambda x: x["score"], reverse=False)

matches3 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000)
]
matches3.sort(key=lambda x: x["score"], reverse=False)

matches4 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000)
]
matches4.sort(key=lambda x: x["score"], reverse=False)

matches5 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000)
]
matches5.sort(key=lambda x: x["score"], reverse=False)

results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}}
results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}}
results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}}
results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}}
results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}}

aggregator.add_results(results1)
aggregator.add_results(results2)
aggregator.add_results(results3)
aggregator.add_results(results4)
aggregator.add_results(results5)

merged_matches = matches1 + matches2 + matches3 + matches4 + matches5
merged_matches.sort(key=lambda x: x["score"], reverse=False)

results = aggregator.get_results()
assert results.usage.read_units == 25
assert len(results.matches) == 1000
assert results.matches[0].score == merged_matches[0]["score"]
assert results.matches[1].score == merged_matches[1]["score"]
assert results.matches[2].score == merged_matches[2]["score"]
assert results.matches[3].score == merged_matches[3]["score"]
assert results.matches[4].score == merged_matches[4]["score"]

def test_still_correct_with_early_return_generated_dotproduct(self):
aggregator = QueryResultsAggregator(top_k=1000)
matches1 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000)
]
matches1.sort(key=lambda x: x["score"], reverse=True)

matches2 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000)
]
matches2.sort(key=lambda x: x["score"], reverse=True)

matches3 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000)
]
matches3.sort(key=lambda x: x["score"], reverse=True)

matches4 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000)
]
matches4.sort(key=lambda x: x["score"], reverse=True)

matches5 = [
{"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000)
]
matches5.sort(key=lambda x: x["score"], reverse=True)

results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}}
results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}}
results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}}
results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}}
results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}}

aggregator.add_results(results1)
aggregator.add_results(results2)
aggregator.add_results(results3)
aggregator.add_results(results4)
aggregator.add_results(results5)

merged_matches = matches1 + matches2 + matches3 + matches4 + matches5
merged_matches.sort(key=lambda x: x["score"], reverse=True)

results = aggregator.get_results()
assert results.usage.read_units == 25
assert len(results.matches) == 1000
assert results.matches[0].score == merged_matches[0]["score"]
assert results.matches[1].score == merged_matches[1]["score"]
assert results.matches[2].score == merged_matches[2]["score"]
assert results.matches[3].score == merged_matches[3]["score"]
assert results.matches[4].score == merged_matches[4]["score"]


class TestQueryResultsAggregatorOutputUX:
def test_can_interact_with_attributes(self):
Expand Down

0 comments on commit 77a731b

Please sign in to comment.