# Add query_namespaces (#409)
## Problem

Sometimes people would like to run a query across multiple namespaces and combine the results.

## Solution

Run a query for each namespace in parallel, then merge the results into a single ranked list using a heap.

```python
from pinecone import Pinecone
import random

pc = Pinecone(api_key='api-key')

# pool_threads controls how many requests the index client can run in parallel
index = pc.Index(
    host="https://indexhost/",
    pool_threads=10
)

dimension = 1024  # example value; use your index's actual dimension
query_vec = [random.random()] * dimension

combined_results = index.query_namespaces(
    vector=query_vec,
    namespaces=["ns1", "ns2", "ns3", "ns4"],
    include_values=False,
    include_metadata=True,
    filter={"publication_date": {"$eq":"Last3Months"}},
    top_k=100
)
```
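
The merged `QueryNamespacesResults` carries the summed read usage and the matches from every namespace in one ranked list, with each match annotated by its source namespace. A small sketch of consuming the results, using the field names from the aggregator added in this PR:

```python
# Each match is a ScoredVectorWithNamespace; usage sums read units across namespaces
for match in combined_results.matches:
    print(match.namespace, match.id, match.score)
print(combined_results.usage)
```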

## TODO

A gRPC implementation of this will follow in a separate PR. I have a WIP branch, but some mypy type issues were causing me headaches and I'd rather land this part first.

## Type of Change

- [x] New feature (non-breaking change which adds functionality)

## Test Plan

Added integration tests
jhamon authored Nov 13, 2024
1 parent b4bfae8 commit e668c89
Showing 7 changed files with 1,120 additions and 26 deletions.
## pinecone/core/openapi/shared/api_client.py (+44 −19)
```diff
@@ -8,6 +8,24 @@
 import typing
 from urllib.parse import quote
 from urllib3.fields import RequestField
+import time
+import random
+
+
+def retry_api_call(
+    func, args=(), kwargs={}, retries=3, backoff=1, jitter=0.5
+):
+    attempts = 0
+    while attempts < retries:
+        try:
+            return func(*args, **kwargs)  # Attempt to call __call_api
+        except Exception as e:
+            attempts += 1
+            if attempts >= retries:
+                print(f"API call failed after {attempts} attempts: {e}")
+                raise  # Re-raise exception if retries are exhausted
+            sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter)
+            # print(f"Retrying ({attempts}/{retries}) in {sleep_time:.2f} seconds after error: {e}")
+            time.sleep(sleep_time)
+
+
 from pinecone.core.openapi.shared import rest
```
```diff
@@ -397,25 +415,32 @@ def call_api(
         )

         return self.pool.apply_async(
-            self.__call_api,
-            (
-                resource_path,
-                method,
-                path_params,
-                query_params,
-                header_params,
-                body,
-                post_params,
-                files,
-                response_type,
-                auth_settings,
-                _return_http_data_only,
-                collection_formats,
-                _preload_content,
-                _request_timeout,
-                _host,
-                _check_type,
-            ),
+            retry_api_call,
+            args=(
+                self.__call_api,  # Pass the API call function as the first argument
+                (
+                    resource_path,
+                    method,
+                    path_params,
+                    query_params,
+                    header_params,
+                    body,
+                    post_params,
+                    files,
+                    response_type,
+                    auth_settings,
+                    _return_http_data_only,
+                    collection_formats,
+                    _preload_content,
+                    _request_timeout,
+                    _host,
+                    _check_type,
+                ),
+                {},  # empty kwargs dictionary
+                3,  # retries
+                1,  # backoff time
+                0.5,  # jitter
+            ),
         )

     def request(
```
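
For illustration only, here is a condensed standalone sketch of the retry behavior wired in above. `retry_with_backoff` is a simplified stand-in for `retry_api_call` (kwargs handling omitted), `flaky` is a hypothetical function, and the defaults mirror the hardcoded values passed from `call_api`:

```python
import random
import time

def retry_with_backoff(func, retries=3, backoff=1, jitter=0.5):
    """Condensed version of retry_api_call: exponential backoff plus jitter."""
    attempts = 0
    while attempts < retries:
        try:
            return func()
        except Exception:
            attempts += 1
            if attempts >= retries:
                raise  # retries exhausted, surface the last error
            # Sleep 1s, 2s, 4s, ... plus up to `jitter` seconds of randomness
            time.sleep(backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter))

calls = {"n": 0}

def flaky():
    # Hypothetical transient failure: errors twice, then succeeds
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"

print(retry_with_backoff(flaky))  # sleeps ~1s then ~2s before printing "ok"
```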
## pinecone/data/index.py (+82 −2)
```diff
@@ -33,6 +33,8 @@
 )
 from .features.bulk_import import ImportFeatureMixin
 from .vector_factory import VectorFactory
+from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
+from multiprocessing.pool import ApplyResult

 from pinecone_plugin_interface import load_and_install as install_plugins
```

```diff
@@ -387,7 +389,7 @@ def query(
             Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
         ] = None,
         **kwargs,
-    ) -> QueryResponse:
+    ) -> Union[QueryResponse, ApplyResult]:
        """
        The Query operation searches a namespace, using a query vector.
        It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
```
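
The widened return type reflects that `query()` with `async_req=True` now returns the thread pool's `ApplyResult` instead of a parsed `QueryResponse`, so callers block on `.get()` when they need the value. A hypothetical sketch, reusing `index` and `query_vec` from the example above:

```python
# async_req=True hands back an ApplyResult from the index's thread pool
future = index.query(vector=query_vec, namespace="ns1", top_k=10, async_req=True)
response = future.get()  # block until the underlying HTTP call completes
```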
```diff
@@ -429,6 +431,39 @@ def query(
         and namespace name.
         """

+        response = self._query(
+            *args,
+            top_k=top_k,
+            vector=vector,
+            id=id,
+            namespace=namespace,
+            filter=filter,
+            include_values=include_values,
+            include_metadata=include_metadata,
+            sparse_vector=sparse_vector,
+            **kwargs,
+        )
+
+        if kwargs.get("async_req", False):
+            return response
+        else:
+            return parse_query_response(response)
+
+    def _query(
+        self,
+        *args,
+        top_k: int,
+        vector: Optional[List[float]] = None,
+        id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        include_values: Optional[bool] = None,
+        include_metadata: Optional[bool] = None,
+        sparse_vector: Optional[
+            Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
+        ] = None,
+        **kwargs,
+    ) -> QueryResponse:
         if len(args) > 0:
             raise ValueError(
                 "The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
```
```diff
@@ -461,7 +496,52 @@ def query(
             ),
             **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
         )
-        return parse_query_response(response)
+        return response

+    @validate_and_convert_errors
+    def query_namespaces(
+        self,
+        vector: List[float],
+        namespaces: List[str],
+        top_k: Optional[int] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        include_values: Optional[bool] = None,
+        include_metadata: Optional[bool] = None,
+        sparse_vector: Optional[
+            Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
+        ] = None,
+        **kwargs,
+    ) -> QueryNamespacesResults:
+        if namespaces is None or len(namespaces) == 0:
+            raise ValueError("At least one namespace must be specified")
+        if len(vector) == 0:
+            raise ValueError("Query vector must not be empty")
+
+        overall_topk = top_k if top_k is not None else 10
+        aggregator = QueryResultsAggregator(top_k=overall_topk)
+
+        target_namespaces = set(namespaces)  # dedup namespaces
+        async_results = [
+            self.query(
+                vector=vector,
+                namespace=ns,
+                top_k=overall_topk,
+                filter=filter,
+                include_values=include_values,
+                include_metadata=include_metadata,
+                sparse_vector=sparse_vector,
+                async_req=True,
+                **kwargs,
+            )
+            for ns in target_namespaces
+        ]
+
+        for result in async_results:
+            response = result.get()
+            aggregator.add_results(response)
+
+        final_results = aggregator.get_results()
+        return final_results

     @validate_and_convert_errors
     def update(
```
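
`query_namespaces` fans the same query out over the deduplicated namespace set via the index's existing thread pool (`async_req=True`), then blocks on each `ApplyResult` and feeds the raw responses to the aggregator. The same fan-out/collect pattern in isolation, as a rough sketch with a stand-in work function:

```python
from multiprocessing.pool import ThreadPool

def fake_query(ns):
    # Stand-in for index.query(..., namespace=ns, async_req=True)
    return {"namespace": ns, "matches": []}

pool = ThreadPool(processes=4)  # analogous to pool_threads on the index client
futures = [pool.apply_async(fake_query, (ns,)) for ns in {"ns1", "ns2", "ns3"}]
responses = [f.get() for f in futures]  # blocks until all queries return
```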
## pinecone/data/query_results_aggregator.py (new file, +193)

```python
from typing import List, Tuple, Optional, Any, Dict
import json
import heapq
from pinecone.core.openapi.data.models import Usage
from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse

from dataclasses import dataclass, asdict


@dataclass
class ScoredVectorWithNamespace:
    namespace: str
    score: float
    id: str
    values: List[float]
    sparse_values: dict
    metadata: dict

    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(self._truncate(asdict(self)), indent=4)

    def __json__(self):
        return self._truncate(asdict(self))

    def _truncate(self, obj, max_items=2):
        """
        Recursively traverse and truncate lists that exceed max_items length.
        Only display the "... X more" message if at least 2 elements are hidden.
        """
        if obj is None:
            return None  # Skip None values
        elif isinstance(obj, list):
            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
            if len(filtered_list) > max_items:
                # Show the truncation message only if more than 1 item is hidden
                remaining_items = len(filtered_list) - max_items
                if remaining_items > 1:
                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
                else:
                    # If only 1 item remains, show it
                    return filtered_list
            return filtered_list
        elif isinstance(obj, dict):
            # Recursively process dictionaries, omitting None values
            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
        return obj


@dataclass
class QueryNamespacesResults:
    usage: Usage
    matches: List[ScoredVectorWithNamespace]

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in QueryNamespacesResults")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(
            {
                "usage": self.usage.to_dict(),
                "matches": [match.__json__() for match in self.matches],
            },
            indent=4,
        )


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    def __init__(self):
        super().__init__(
            "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
        )


class QueryResultsAggregatorInvalidTopKError(Exception):
    def __init__(self, top_k: int):
        super().__init__(
            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
        )


class QueryResultsAggregator:
    def __init__(self, top_k: int):
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        return True

    def _dotproduct_heap_item(self, match, ns):
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Assume we have dotproduct scores sorted in descending order
                if self.is_dotproduct and match["score"] < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))

    def add_results(self, results: Dict[str, Any]):
        if self.read:
            # This is mainly just to sanity check in test cases which get quite confusing
            # if you read results twice due to the heap being emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # This condition should match the second time we add results containing
                # only one match. We need at least two matches in a single response in order
                # to infer the similarity metric.
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # I don't think this branch can ever actually be reached, but the type checker disagrees
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
```
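
A rough usage sketch of the aggregator on its own, with hand-built response dicts rather than real API responses (module path as added in this PR). Because the scores within each response are descending, the aggregator infers a dotproduct-style metric where higher is better:

```python
from pinecone.data.query_results_aggregator import QueryResultsAggregator

agg = QueryResultsAggregator(top_k=3)

# Hand-built stand-ins for per-namespace query responses
agg.add_results({
    "namespace": "ns1",
    "usage": {"readUnits": 5},
    "matches": [{"id": "a", "score": 0.9}, {"id": "b", "score": 0.7}],
})
agg.add_results({
    "namespace": "ns2",
    "usage": {"readUnits": 5},
    "matches": [{"id": "c", "score": 0.8}, {"id": "d", "score": 0.6}],
})

results = agg.get_results()
# Best-first across namespaces: a (ns1, 0.9), c (ns2, 0.8), b (ns1, 0.7);
# results.usage.read_units sums to 10
```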