Skip to content

Commit

Permalink
Implement query_namespaces over grpc (#416)
Browse files Browse the repository at this point in the history
## Problem

I want to maintain parity across REST / GRPC implementations. This PR
adds a query_namespaces implementation for the GRPC index client.

## Solution

Use a ThreadPoolExecutor to execute the per-namespace queries in parallel,
then aggregate the results with QueryResultsAggregator.

## Usage

```python
# Example: query several namespaces of one index over gRPC and merge the results.
import random
from pinecone.grpc import PineconeGRPC

pc = PineconeGRPC(api_key="key")
# pool_threads sizes the thread pool that query_namespaces uses to fan out
# one query per namespace.
index = pc.Index(host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io", pool_threads=25)

# Random 1024-element query vector, for demonstration only
# (presumably matches the index dimension -- adjust for your own index).
query_vec = [random.random() for i in range(1024)]
# The same vector, filter and top_k are sent to every listed namespace;
# duplicate namespace names are only queried once.
combined_results = index.query_namespaces(
    vector=query_vec,
    namespaces=["ns1", "ns2", "ns3", "ns4"],
    include_values=False,
    include_metadata=True,
    filter={"genre": {"$eq": "drama"}},
    top_k=50
)

# Print the id and score of each match in the combined result set,
# then the usage reported on the combined result.
for vec in combined_results.matches:
    print(vec.get('id'), vec.get('score'))
print(combined_results.usage)
```

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

## Test Plan

Describe specific steps for validating this change.
  • Loading branch information
jhamon authored Nov 13, 2024
1 parent eade7dd commit ab28227
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 6 deletions.
10 changes: 10 additions & 0 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pinecone import Config
from .config import GRPCClientConfig
from .grpc_runner import GrpcRunner
from concurrent.futures import ThreadPoolExecutor

from pinecone_plugin_interface import load_and_install as install_plugins

Expand All @@ -29,10 +30,12 @@ def __init__(
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
pool_threads: Optional[int] = None,
_endpoint_override: Optional[str] = None,
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.pool_threads = pool_threads

self._endpoint_override = _endpoint_override

Expand All @@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs):
except Exception as e:
_logger.error(f"Error loading plugins in GRPCIndex: {e}")

@property
def threadpool_executor(self):
    """Return the shared ThreadPoolExecutor, creating it on first access.

    The pool is built lazily with ``self.pool_threads`` workers, falling
    back to 10 when ``pool_threads`` is ``None`` (or otherwise falsy), and
    is cached on the instance as ``_pool`` so later accesses reuse it.

    Returns:
        ThreadPoolExecutor: the executor cached on this instance.
    """
    # ``_pool`` may never have been assigned (the constructor shown in this
    # change does not set it), so read it defensively instead of raising
    # AttributeError on the first access.
    pool = getattr(self, "_pool", None)
    if pool is None:
        pool = ThreadPoolExecutor(max_workers=self.pool_threads or 10)
        self._pool = pool
    return pool

@property
@abstractmethod
def stub_class(self):
Expand Down
48 changes: 47 additions & 1 deletion pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast

from google.protobuf import json_format

from tqdm.autonotebook import tqdm
from concurrent.futures import as_completed, Future


from .utils import (
dict_to_proto_struct,
Expand Down Expand Up @@ -35,6 +37,7 @@
SparseValues as GRPCSparseValues,
)
from pinecone import Vector as NonGRPCVector
from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture
Expand Down Expand Up @@ -402,6 +405,49 @@ def query(
json_response = json_format.MessageToDict(response)
return parse_query_response(json_response, _check_type=False)

def query_namespaces(
    self,
    vector: List[float],
    namespaces: List[str],
    top_k: Optional[int] = None,
    filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
    include_values: Optional[bool] = None,
    include_metadata: Optional[bool] = None,
    sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
    **kwargs,
) -> QueryNamespacesResults:
    """Run the same query against several namespaces and merge the results.

    One ``self.query`` call per distinct namespace is submitted to
    ``self.threadpool_executor``; each response is handed to a
    ``QueryResultsAggregator`` as soon as its future completes.

    Args:
        vector: Dense query vector; must be non-empty.
        namespaces: Namespaces to search; duplicates are queried only once.
        top_k: Size of the merged result set, also used as the per-namespace
            ``top_k``. Defaults to 10 when ``None``.
        filter: Metadata filter forwarded unchanged to every query.
        include_values: Forwarded unchanged to every query.
        include_metadata: Forwarded unchanged to every query.
        sparse_vector: Forwarded unchanged to every query.
        **kwargs: Extra keyword arguments forwarded to ``self.query``.

    Returns:
        The value produced by ``QueryResultsAggregator.get_results()``.

    Raises:
        ValueError: If ``namespaces`` is ``None`` or empty, or ``vector`` is
            empty.
    """
    if namespaces is None or len(namespaces) == 0:
        raise ValueError("At least one namespace must be specified")
    if len(vector) == 0:
        raise ValueError("Query vector must not be empty")

    effective_top_k = 10 if top_k is None else top_k
    merged = QueryResultsAggregator(top_k=effective_top_k)

    # Fan out: one blocking query per unique namespace on the shared pool.
    pending = []
    for ns in set(namespaces):
        job = self.threadpool_executor.submit(
            self.query,
            vector=vector,
            namespace=ns,
            top_k=effective_top_k,
            filter=filter,
            include_values=include_values,
            include_metadata=include_metadata,
            sparse_vector=sparse_vector,
            async_req=False,
            **kwargs,
        )
        pending.append(job)

    # Fan in: fold each response into the aggregator in completion order.
    # result() re-raises any exception raised by the underlying query.
    for finished in as_completed(pending):
        merged.add_results(finished.result())

    return merged.get_results()

def update(
self,
id: str,
Expand Down
4 changes: 3 additions & 1 deletion pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ def Index(self, name: str = "", host: str = "", **kwargs):
# Use host if it is provided, otherwise get host from describe_index
index_host = host or self.index_host_store.get_host(self.index_api, self.config, name)

pt = kwargs.pop("pool_threads", None) or self.pool_threads

config = ConfigBuilder.build(
api_key=self.config.api_key,
host=index_host,
source_tag=self.config.source_tag,
proxy_url=self.config.proxy_url,
ssl_ca_certs=self.config.ssl_ca_certs,
)
return GRPCIndex(index_name=name, config=config, **kwargs)
return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs)
4 changes: 0 additions & 4 deletions tests/integration/data/test_query_namespaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import os
from ..helpers import random_string, poll_stats_for_namespace
from pinecone.data.query_results_aggregator import (
QueryResultsAggregatorInvalidTopKError,
Expand All @@ -9,9 +8,6 @@
from pinecone import Vector


@pytest.mark.skipif(
os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest"
)
class TestQueryNamespacesRest:
def test_query_namespaces(self, idx):
ns_prefix = random_string(5)
Expand Down

0 comments on commit ab28227

Please sign in to comment.