diff --git a/pinecone/grpc/base.py b/pinecone/grpc/base.py index cc4ca2c6..8964e72d 100644 --- a/pinecone/grpc/base.py +++ b/pinecone/grpc/base.py @@ -10,6 +10,7 @@ from pinecone import Config from .config import GRPCClientConfig from .grpc_runner import GrpcRunner +from concurrent.futures import ThreadPoolExecutor from pinecone_plugin_interface import load_and_install as install_plugins @@ -29,10 +30,12 @@ def __init__( config: Config, channel: Optional[Channel] = None, grpc_config: Optional[GRPCClientConfig] = None, + pool_threads: Optional[int] = None, _endpoint_override: Optional[str] = None, ): self.config = config self.grpc_client_config = grpc_config or GRPCClientConfig() + self.pool_threads = pool_threads self._endpoint_override = _endpoint_override @@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs): except Exception as e: _logger.error(f"Error loading plugins in GRPCIndex: {e}") + @property + def threadpool_executor(self): + if self._pool is None: + pt = self.pool_threads or 10 + self._pool = ThreadPoolExecutor(max_workers=pt) + return self._pool + @property @abstractmethod def stub_class(self): diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index eba611b7..6791ae68 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -1,9 +1,11 @@ import logging -from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast +from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast from google.protobuf import json_format from tqdm.autonotebook import tqdm +from concurrent.futures import as_completed, Future + from .utils import ( dict_to_proto_struct, @@ -35,6 +37,7 @@ SparseValues as GRPCSparseValues, ) from pinecone import Vector as NonGRPCVector +from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub from .base import GRPCIndexBase from .future import PineconeGrpcFuture @@ -402,6 +405,49 @@ def query( json_response = json_format.MessageToDict(response) return parse_query_response(json_response, _check_type=False) + def query_namespaces( + self, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + **kwargs, + ) -> QueryNamespacesResults: + if namespaces is None or len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + + target_namespaces = set(namespaces) # dedup namespaces + futures = [ + self.threadpool_executor.submit( + self.query, + vector=vector, + namespace=ns, + top_k=overall_topk, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + async_req=False, + **kwargs, + ) + for ns in target_namespaces + ] + + only_futures = cast(Iterable[Future], futures) + for response in as_completed(only_futures): + aggregator.add_results(response.result()) + + final_results = aggregator.get_results() + return final_results + def update( self, id: str, diff --git a/pinecone/grpc/pinecone.py b/pinecone/grpc/pinecone.py index af6a8baa..c78481ff 100644 --- a/pinecone/grpc/pinecone.py +++ b/pinecone/grpc/pinecone.py @@ -124,6 +124,8 @@ def Index(self, name: str = "", host: str = "", **kwargs): # Use host if it is provided, otherwise get host from describe_index index_host = host or self.index_host_store.get_host(self.index_api, self.config, name) + pt = kwargs.pop("pool_threads", None) or self.pool_threads + config = ConfigBuilder.build( api_key=self.config.api_key, host=index_host, @@ -131,4 +133,4 @@ def Index(self, name: str = "", host: str = "", **kwargs): proxy_url=self.config.proxy_url, ssl_ca_certs=self.config.ssl_ca_certs, ) - return GRPCIndex(index_name=name, config=config, **kwargs) + return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs) diff --git a/tests/integration/data/test_query_namespaces.py b/tests/integration/data/test_query_namespaces.py index e52c58b0..414cea69 100644 --- a/tests/integration/data/test_query_namespaces.py +++ b/tests/integration/data/test_query_namespaces.py @@ -1,5 +1,4 @@ import pytest -import os from ..helpers import random_string, poll_stats_for_namespace from pinecone.data.query_results_aggregator import ( QueryResultsAggregatorInvalidTopKError, @@ -9,9 +8,6 @@ from pinecone import Vector -@pytest.mark.skipif( - os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest" -) class TestQueryNamespacesRest: def test_query_namespaces(self, idx): ns_prefix = random_string(5)