From ab28227347c715d3638e300e9ae50d67da9f9ee3 Mon Sep 17 00:00:00 2001 From: Jennifer Hamon Date: Wed, 13 Nov 2024 11:07:39 -0500 Subject: [PATCH] Implement query_namespaces over grpc (#416) ## Problem I want to maintain parity across REST / GRPC implementations. This PR adds a query_namespaces implementation for the GRPC index client. ## Solution Use a ThreadPoolExecutor to execute queries in parallel, then aggregate the results with QueryResultsAggregator. ## Usage ```python import random from pinecone.grpc import PineconeGRPC pc = PineconeGRPC(api_key="key") index = pc.Index(host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io", pool_threads=25) query_vec = [random.random() for i in range(1024)] combined_results = index.query_namespaces( vector=query_vec, namespaces=["ns1", "ns2", "ns3", "ns4"], include_values=False, include_metadata=True, filter={"genre": {"$eq": "drama"}}, top_k=50 ) for vec in combined_results.matches: print(vec.get('id'), vec.get('score')) print(combined_results.usage) ``` ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan Describe specific steps for validating this change. 
--- pinecone/grpc/base.py | 10 ++++ pinecone/grpc/index_grpc.py | 48 ++++++++++++++++++- pinecone/grpc/pinecone.py | 4 +- .../integration/data/test_query_namespaces.py | 4 -- 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/pinecone/grpc/base.py b/pinecone/grpc/base.py index cc4ca2c6..8964e72d 100644 --- a/pinecone/grpc/base.py +++ b/pinecone/grpc/base.py @@ -10,6 +10,7 @@ from pinecone import Config from .config import GRPCClientConfig from .grpc_runner import GrpcRunner +from concurrent.futures import ThreadPoolExecutor from pinecone_plugin_interface import load_and_install as install_plugins @@ -29,10 +30,12 @@ def __init__( config: Config, channel: Optional[Channel] = None, grpc_config: Optional[GRPCClientConfig] = None, + pool_threads: Optional[int] = None, _endpoint_override: Optional[str] = None, ): self.config = config self.grpc_client_config = grpc_config or GRPCClientConfig() + self.pool_threads = pool_threads self._endpoint_override = _endpoint_override @@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs): except Exception as e: _logger.error(f"Error loading plugins in GRPCIndex: {e}") + @property + def threadpool_executor(self): + if self._pool is None: + pt = self.pool_threads or 10 + self._pool = ThreadPoolExecutor(max_workers=pt) + return self._pool + @property @abstractmethod def stub_class(self): diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index eba611b7..6791ae68 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -1,9 +1,11 @@ import logging -from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast +from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast from google.protobuf import json_format from tqdm.autonotebook import tqdm +from concurrent.futures import as_completed, Future + from .utils import ( dict_to_proto_struct, @@ -35,6 +37,7 @@ SparseValues as GRPCSparseValues, ) from pinecone import Vector as 
NonGRPCVector +from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub from .base import GRPCIndexBase from .future import PineconeGrpcFuture @@ -402,6 +405,49 @@ def query( json_response = json_format.MessageToDict(response) return parse_query_response(json_response, _check_type=False) + def query_namespaces( + self, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + **kwargs, + ) -> QueryNamespacesResults: + if namespaces is None or len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + + target_namespaces = set(namespaces) # dedup namespaces + futures = [ + self.threadpool_executor.submit( + self.query, + vector=vector, + namespace=ns, + top_k=overall_topk, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + async_req=False, + **kwargs, + ) + for ns in target_namespaces + ] + + only_futures = cast(Iterable[Future], futures) + for response in as_completed(only_futures): + aggregator.add_results(response.result()) + + final_results = aggregator.get_results() + return final_results + def update( self, id: str, diff --git a/pinecone/grpc/pinecone.py b/pinecone/grpc/pinecone.py index af6a8baa..c78481ff 100644 --- a/pinecone/grpc/pinecone.py +++ b/pinecone/grpc/pinecone.py @@ -124,6 +124,8 @@ def Index(self, name: str = "", host: str = "", **kwargs): # Use host if it is provided, otherwise get 
host from describe_index index_host = host or self.index_host_store.get_host(self.index_api, self.config, name) + pt = kwargs.pop("pool_threads", None) or self.pool_threads + config = ConfigBuilder.build( api_key=self.config.api_key, host=index_host, @@ -131,4 +133,4 @@ def Index(self, name: str = "", host: str = "", **kwargs): proxy_url=self.config.proxy_url, ssl_ca_certs=self.config.ssl_ca_certs, ) - return GRPCIndex(index_name=name, config=config, **kwargs) + return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs) diff --git a/tests/integration/data/test_query_namespaces.py b/tests/integration/data/test_query_namespaces.py index e52c58b0..414cea69 100644 --- a/tests/integration/data/test_query_namespaces.py +++ b/tests/integration/data/test_query_namespaces.py @@ -1,5 +1,4 @@ import pytest -import os from ..helpers import random_string, poll_stats_for_namespace from pinecone.data.query_results_aggregator import ( QueryResultsAggregatorInvalidTopKError, @@ -9,9 +8,6 @@ from pinecone import Vector -@pytest.mark.skipif( - os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest" -) class TestQueryNamespacesRest: def test_query_namespaces(self, idx): ns_prefix = random_string(5)