diff --git a/pinecone/data/__init__.py b/pinecone/data/__init__.py
index a2d6c1a1..02a67767 100644
--- a/pinecone/data/__init__.py
+++ b/pinecone/data/__init__.py
@@ -1,4 +1,5 @@
 from .index import *
+from .import_error import Index, IndexClientInstantiationError
 from .errors import (
     VectorDictionaryMissingKeysError,
     VectorDictionaryExcessKeysError,
diff --git a/pinecone/data/import_error.py b/pinecone/data/import_error.py
new file mode 100644
index 00000000..751872fd
--- /dev/null
+++ b/pinecone/data/import_error.py
@@ -0,0 +1,46 @@
+class IndexClientInstantiationError(Exception):
+    """Raised when the Index client is constructed directly from the pinecone module.
+
+    The message shows the caller's own arguments, rendered with ``repr``, in
+    both the incorrect and the corrected usage example.
+    """
+
+    def __init__(self, index_args, index_kwargs):
+        """Build the explanatory message from the attempted call.
+
+        Args:
+            index_args: Positional arguments passed to ``Index(...)``.
+            index_kwargs: Keyword arguments passed to ``Index(...)``.
+        """
+        # Render both groups as call-site source text; the filter below
+        # drops whichever group is empty so no stray ", " is produced.
+        formatted_args = ", ".join(map(repr, index_args))
+        formatted_kwargs = ", ".join(f"{key}={repr(value)}" for key, value in index_kwargs.items())
+        combined_args = ", ".join([a for a in [formatted_args, formatted_kwargs] if a.strip()])
+
+        self.message = f"""You are attempting to access the Index client directly from the pinecone module. The Index client must be instantiated through the parent Pinecone client instance so that it can inherit shared configurations such as API key.
+
+    INCORRECT USAGE:
+        ```
+        import pinecone
+
+        pc = pinecone.Pinecone(api_key='your-api-key')
+        index = pinecone.Index({combined_args})
+        ```
+
+    CORRECT USAGE:
+        ```
+        from pinecone import Pinecone
+
+        pc = Pinecone(api_key='your-api-key')
+        index = pc.Index({combined_args})
+        ```
+        """
+        super().__init__(self.message)
+
+
+class Index:
+    """Placeholder whose constructor always raises IndexClientInstantiationError."""
+
+    def __init__(self, *args, **kwargs):
+        raise IndexClientInstantiationError(args, kwargs)
diff --git a/pinecone/data/index.py b/pinecone/data/index.py
index 97063cd0..3f74568d 100644
--- a/pinecone/data/index.py
+++ b/pinecone/data/index.py
@@ -27,6 +27,7 @@
     SparseValues,
 )
 from .interfaces import IndexInterface
+from .request_factory import IndexRequestFactory
 from .features.bulk_import import ImportFeatureMixin
 from ..utils import (
     setup_openapi_client,
@@ -34,8 +35,8 @@
     build_plugin_setup_client,
     validate_and_convert_errors,
 )
-from .vector_factory import VectorFactory
 from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
+from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS
 
 from multiprocessing.pool import ApplyResult
 from concurrent.futures import as_completed
@@ -45,7 +46,6 @@
 logger = logging.getLogger(__name__)
 
 __all__ = [
-    "Index",
     "_Index",
     "FetchResponse",
     "QueryRequest",
@@ -64,55 +64,12 @@
     "SparseValues",
 ]
 
-_OPENAPI_ENDPOINT_PARAMS = (
-    "_return_http_data_only",
-    "_preload_content",
-    "_request_timeout",
-    "_check_input_type",
-    "_check_return_type",
-    "_host_index",
-    "async_req",
-    "async_threadpool_executor",
-)
-
 
 def parse_query_response(response: QueryResponse):
     response._data_store.pop("results", None)
     return response
 
 
-class IndexClientInstantiationError(Exception):
-    def __init__(self, index_args, index_kwargs):
-        formatted_args = ", ".join(map(repr, index_args))
-        formatted_kwargs = ", ".join(f"{key}={repr(value)}" for key, value in index_kwargs.items())
-        combined_args = ", ".join([a for a in [formatted_args, formatted_kwargs] if a.strip()])
-
-        self.message = 
f"""You are attempting to access the Index client directly from the pinecone module. The Index client must be instantiated through the parent Pinecone client instance so that it can inherit shared configurations such as API key. - - INCORRECT USAGE: - ``` - import pinecone - - pc = pinecone.Pinecone(api_key='your-api-key') - index = pinecone.Index({combined_args}) - ``` - - CORRECT USAGE: - ``` - from pinecone import Pinecone - - pc = Pinecone(api_key='your-api-key') - index = pc.Index({combined_args}) - ``` - """ - super().__init__(self.message) - - -class Index: - def __init__(self, *args, **kwargs): - raise IndexClientInstantiationError(args, kwargs) - - class _Index(IndexInterface, ImportFeatureMixin): """ A client for interacting with a Pinecone index via REST API. @@ -172,6 +129,9 @@ def _load_plugins(self): except Exception as e: logger.error(f"Error loading plugins in Index: {e}") + def _openapi_kwargs(self, kwargs): + return {k: v for k, v in kwargs.items() if k in OPENAPI_ENDPOINT_PARAMS} + def __enter__(self): return self @@ -221,19 +181,9 @@ def _upsert_batch( _check_type: bool, **kwargs, ) -> UpsertResponse: - args_dict = parse_non_empty_args([("namespace", namespace)]) - - def vec_builder(v): - return VectorFactory.build(v, check_type=_check_type) - return self._vector_api.upsert_vectors( - UpsertRequest( - vectors=list(map(vec_builder, vectors)), - **args_dict, - _check_type=_check_type, - **{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS}, - ), - **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, + IndexRequestFactory.upsert_request(vectors, namespace, _check_type, **kwargs), + **self._openapi_kwargs(kwargs), ) @staticmethod @@ -277,22 +227,11 @@ def delete( filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs, ) -> Dict[str, Any]: - _check_type = kwargs.pop("_check_type", False) - args_dict = parse_non_empty_args( - [("ids", ids), ("delete_all", delete_all), ("namespace", 
namespace), ("filter", filter)] - ) - return self._vector_api.delete_vectors( - DeleteRequest( - **args_dict, - **{ - k: v - for k, v in kwargs.items() - if k not in _OPENAPI_ENDPOINT_PARAMS and v is not None - }, - _check_type=_check_type, + IndexRequestFactory.delete_request( + ids=ids, delete_all=delete_all, namespace=namespace, filter=filter, **kwargs ), - **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, + **self._openapi_kwargs(kwargs), ) @validate_and_convert_errors @@ -354,35 +293,18 @@ def _query( "The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')" ) - if vector is not None and id is not None: - raise ValueError("Cannot specify both `id` and `vector`") - - _check_type = kwargs.pop("_check_type", False) - - sparse_vector = self._parse_sparse_values_arg(sparse_vector) - args_dict = parse_non_empty_args( - [ - ("vector", vector), - ("id", id), - ("queries", None), - ("top_k", top_k), - ("namespace", namespace), - ("filter", filter), - ("include_values", include_values), - ("include_metadata", include_metadata), - ("sparse_vector", sparse_vector), - ] - ) - - response = self._vector_api.query_vectors( - QueryRequest( - **args_dict, - _check_type=_check_type, - **{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS}, - ), - **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, + request = IndexRequestFactory.query_request( + top_k=top_k, + vector=vector, + id=id, + namespace=namespace, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + **kwargs, ) - return response + return self._vector_api.query_vectors(request, **self._openapi_kwargs(kwargs)) @validate_and_convert_errors def query_namespaces( @@ -445,40 +367,25 @@ def update( ] = None, **kwargs, ) -> Dict[str, Any]: - _check_type = 
kwargs.pop("_check_type", False) - sparse_values = self._parse_sparse_values_arg(sparse_values) - args_dict = parse_non_empty_args( - [ - ("values", values), - ("set_metadata", set_metadata), - ("namespace", namespace), - ("sparse_values", sparse_values), - ] - ) return self._vector_api.update_vector( - UpdateRequest( + IndexRequestFactory.update_request( id=id, - **args_dict, - _check_type=_check_type, - **{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS}, + values=values, + set_metadata=set_metadata, + namespace=namespace, + sparse_values=sparse_values, + **kwargs, ), - **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, + **self._openapi_kwargs(kwargs), ) @validate_and_convert_errors def describe_index_stats( self, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs ) -> DescribeIndexStatsResponse: - _check_type = kwargs.pop("_check_type", False) - args_dict = parse_non_empty_args([("filter", filter)]) - return self._vector_api.describe_index_stats( - DescribeIndexStatsRequest( - **args_dict, - **{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS}, - _check_type=_check_type, - ), - **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, + IndexRequestFactory.describe_index_stats_request(filter, **kwargs), + **self._openapi_kwargs(kwargs), ) @validate_and_convert_errors @@ -490,13 +397,12 @@ def list_paginated( namespace: Optional[str] = None, **kwargs, ) -> ListResponse: - args_dict = parse_non_empty_args( - [ - ("prefix", prefix), - ("limit", limit), - ("namespace", namespace), - ("pagination_token", pagination_token), - ] + args_dict = IndexRequestFactory.list_paginated_args( + prefix=prefix, + limit=limit, + pagination_token=pagination_token, + namespace=namespace, + **kwargs, ) return self._vector_api.list_vectors(**args_dict, **kwargs) @@ -512,25 +418,3 @@ def list(self, **kwargs): kwargs.update({"pagination_token": results.pagination.next}) else: done 
= True
-
-    @staticmethod
-    def _parse_sparse_values_arg(
-        sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]],
-    ) -> Optional[SparseValues]:
-        if sparse_values is None:
-            return None
-
-        if isinstance(sparse_values, SparseValues):
-            return sparse_values
-
-        if (
-            not isinstance(sparse_values, dict)
-            or "indices" not in sparse_values
-            or "values" not in sparse_values
-        ):
-            raise ValueError(
-                "Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
-                f"Received: {sparse_values}"
-            )
-
-        return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
diff --git a/pinecone/data/request_factory.py b/pinecone/data/request_factory.py
new file mode 100644
index 00000000..20dd5a42
--- /dev/null
+++ b/pinecone/data/request_factory.py
@@ -0,0 +1,214 @@
+import logging
+from typing import Union, List, Optional, Dict, Any
+
+from pinecone.core.openapi.db_data.models import (
+    QueryRequest,
+    UpsertRequest,
+    Vector,
+    DeleteRequest,
+    UpdateRequest,
+    DescribeIndexStatsRequest,
+    SparseValues,
+)
+from ..utils import parse_non_empty_args
+from .vector_factory import VectorFactory
+from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS
+
+logger = logging.getLogger(__name__)
+
+
+def parse_sparse_values_arg(
+    sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]],
+) -> Optional[SparseValues]:
+    """Normalize a sparse-values argument into a ``SparseValues`` model.
+
+    ``None`` and existing ``SparseValues`` instances are returned unchanged; a
+    dict must contain both an ``"indices"`` and a ``"values"`` key.
+
+    Raises:
+        ValueError: If the argument is not ``None``, not a ``SparseValues``
+            instance, and not a dict containing both required keys.
+    """
+    if sparse_values is None:
+        return None
+
+    if isinstance(sparse_values, SparseValues):
+        return sparse_values
+
+    if (
+        not isinstance(sparse_values, dict)
+        or "indices" not in sparse_values
+        or "values" not in sparse_values
+    ):
+        raise ValueError(
+            # Trailing space separates the two sentences of the message.
+            "Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}. "
+            f"Received: {sparse_values}"
+        )
+
+    return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
+
+
+def non_openapi_kwargs(kwargs):
+    """Return the entries of ``kwargs`` whose keys are not in OPENAPI_ENDPOINT_PARAMS."""
+    return {k: v for k, v in kwargs.items() if k not in OPENAPI_ENDPOINT_PARAMS}
+
+
+class IndexRequestFactory:
+    """Builds request models and argument dicts for the index client.
+
+    The request builders pass their named arguments through
+    ``parse_non_empty_args`` and forward to the request model only those extra
+    keyword arguments that are not OpenAPI endpoint parameters.
+    """
+
+    @staticmethod
+    def query_request(
+        top_k: int,
+        vector: Optional[List[float]] = None,
+        id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        include_values: Optional[bool] = None,
+        include_metadata: Optional[bool] = None,
+        sparse_vector: Optional[
+            Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
+        ] = None,
+        **kwargs,
+    ) -> QueryRequest:
+        """Build the ``QueryRequest`` for a query call.
+
+        Raises:
+            ValueError: If both ``id`` and ``vector`` are given, or if
+                ``sparse_vector`` is not a valid sparse-values argument.
+        """
+        if vector is not None and id is not None:
+            raise ValueError("Cannot specify both `id` and `vector`")
+
+        # Pop first so `_check_type` is not also forwarded via the extra kwargs.
+        _check_type = kwargs.pop("_check_type", False)
+        sparse_vector = parse_sparse_values_arg(sparse_vector)
+        args_dict = parse_non_empty_args(
+            [
+                ("vector", vector),
+                ("id", id),
+                ("queries", None),
+                ("top_k", top_k),
+                ("namespace", namespace),
+                ("filter", filter),
+                ("include_values", include_values),
+                ("include_metadata", include_metadata),
+                ("sparse_vector", sparse_vector),
+            ]
+        )
+
+        return QueryRequest(
+            **args_dict, _check_type=_check_type, **non_openapi_kwargs(kwargs)
+        )
+
+    @staticmethod
+    def upsert_request(
+        vectors: Union[List[Vector], List[tuple], List[dict]],
+        namespace: Optional[str],
+        _check_type: bool,
+        **kwargs,
+    ) -> UpsertRequest:
+        """Build the ``UpsertRequest`` for one batch of vectors.
+
+        Every item in ``vectors`` is converted with ``VectorFactory.build``.
+        """
+        args_dict = parse_non_empty_args([("namespace", namespace)])
+
+        return UpsertRequest(
+            vectors=[VectorFactory.build(v, check_type=_check_type) for v in vectors],
+            **args_dict,
+            _check_type=_check_type,
+            **non_openapi_kwargs(kwargs),
+        )
+
+    @staticmethod
+    def delete_request(
+        ids: Optional[List[str]] = None,
+        delete_all: Optional[bool] = None,
+        namespace: Optional[str] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        **kwargs,
+    ) -> DeleteRequest:
+        """Build the ``DeleteRequest`` for a delete call.
+
+        Extra keyword arguments whose value is ``None`` are omitted from the
+        request, as the index client's ``delete`` did before this logic moved here.
+        """
+        _check_type = kwargs.pop("_check_type", False)
+        args_dict = parse_non_empty_args(
+            [("ids", ids), ("delete_all", delete_all), ("namespace", namespace), ("filter", filter)]
+        )
+        # Preserve the pre-refactor behavior: `None`-valued extras were dropped.
+        extra_args = {k: v for k, v in non_openapi_kwargs(kwargs).items() if v is not None}
+        return DeleteRequest(**args_dict, **extra_args, _check_type=_check_type)
+
+    @staticmethod
+    def update_request(
+        id: str,
+        values: Optional[List[float]] = None,
+        set_metadata: Optional[
+            Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]
+        ] = None,
+        namespace: Optional[str] = None,
+        sparse_values: Optional[
+            Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
+        ] = None,
+        **kwargs,
+    ) -> UpdateRequest:
+        """Build the ``UpdateRequest`` for an update call.
+
+        Raises:
+            ValueError: If ``sparse_values`` is not a valid sparse-values argument.
+        """
+        _check_type = kwargs.pop("_check_type", False)
+        sparse_values = parse_sparse_values_arg(sparse_values)
+        args_dict = parse_non_empty_args(
+            [
+                ("values", values),
+                ("set_metadata", set_metadata),
+                ("namespace", namespace),
+                ("sparse_values", sparse_values),
+            ]
+        )
+
+        return UpdateRequest(
+            id=id, **args_dict, _check_type=_check_type, **non_openapi_kwargs(kwargs)
+        )
+
+    @staticmethod
+    def describe_index_stats_request(
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs
+    ) -> DescribeIndexStatsRequest:
+        """Build the ``DescribeIndexStatsRequest`` for a stats call."""
+        _check_type = kwargs.pop("_check_type", False)
+        args_dict = parse_non_empty_args([("filter", filter)])
+
+        return DescribeIndexStatsRequest(
+            **args_dict, **non_openapi_kwargs(kwargs), _check_type=_check_type
+        )
+
+    @staticmethod
+    def list_paginated_args(
+        prefix: Optional[str] = None,
+        limit: Optional[int] = None,
+        pagination_token: Optional[str] = None,
+        namespace: Optional[str] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """Return the named list arguments filtered through ``parse_non_empty_args``.
+
+        Extra keyword arguments are accepted but not included in the result.
+        """
+        return parse_non_empty_args(
+            [
+                ("prefix", prefix),
+                ("limit", limit),
+                ("namespace", namespace),
+                ("pagination_token", pagination_token),
+            ]
+        )
diff --git a/pinecone/openapi_support/__init__.py b/pinecone/openapi_support/__init__.py
index 8bf227bb..bfd498f7 100644
--- a/pinecone/openapi_support/__init__.py
+++ 
b/pinecone/openapi_support/__init__.py
@@ -34,6 +34,7 @@
     none_type,
 )
 from .rest import RESTClientObject, RESTResponse
+from .constants import OPENAPI_ENDPOINT_PARAMS
 
 from datetime import date, datetime  # noqa: F401
 from dateutil.parser import parse
diff --git a/pinecone/openapi_support/constants.py b/pinecone/openapi_support/constants.py
new file mode 100644
index 00000000..92864e5f
--- /dev/null
+++ b/pinecone/openapi_support/constants.py
@@ -0,0 +1,12 @@
+# Keyword-argument names that the index client forwards to the generated
+# OpenAPI endpoint call instead of to the request model it builds.
+OPENAPI_ENDPOINT_PARAMS = (
+    "_return_http_data_only",
+    "_preload_content",
+    "_request_timeout",
+    "_check_input_type",
+    "_check_return_type",
+    "_host_index",
+    "async_req",
+    "async_threadpool_executor",
+)