Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Extract index request factory #420

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pinecone/data/import_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class IndexClientInstantiationError(Exception):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I was already in a mode of pulling things out of index.py, I decided to move this to a different file as well.

def __init__(self, index_args, index_kwargs):
formatted_args = ", ".join(map(repr, index_args))
formatted_kwargs = ", ".join(f"{key}={repr(value)}" for key, value in index_kwargs.items())
combined_args = ", ".join([a for a in [formatted_args, formatted_kwargs] if a.strip()])

self.message = f"""You are attempting to access the Index client directly from the pinecone module. The Index client must be instantiated through the parent Pinecone client instance so that it can inherit shared configurations such as API key.

INCORRECT USAGE:
```
import pinecone

pc = pinecone.Pinecone(api_key='your-api-key')
index = pinecone.Index({combined_args})
```

CORRECT USAGE:
```
from pinecone import Pinecone

pc = Pinecone(api_key='your-api-key')
index = pc.Index({combined_args})
```
"""
super().__init__(self.message)


class Index:
def __init__(self, *args, **kwargs):
raise IndexClientInstantiationError(args, kwargs)
188 changes: 36 additions & 152 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
SparseValues,
)
from .interfaces import IndexInterface
from .request_factory import IndexRequestFactory
from .features.bulk_import import ImportFeatureMixin
from ..utils import (
setup_openapi_client,
parse_non_empty_args,
build_plugin_setup_client,
validate_and_convert_errors,
)
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS

from multiprocessing.pool import ApplyResult
from concurrent.futures import as_completed
Expand All @@ -45,7 +46,6 @@
logger = logging.getLogger(__name__)

__all__ = [
"Index",
"_Index",
"FetchResponse",
"QueryRequest",
Expand All @@ -64,55 +64,12 @@
"SparseValues",
]

_OPENAPI_ENDPOINT_PARAMS = (
"_return_http_data_only",
"_preload_content",
"_request_timeout",
"_check_input_type",
"_check_return_type",
"_host_index",
"async_req",
"async_threadpool_executor",
)


def parse_query_response(response: QueryResponse):
response._data_store.pop("results", None)
return response


class IndexClientInstantiationError(Exception):
def __init__(self, index_args, index_kwargs):
formatted_args = ", ".join(map(repr, index_args))
formatted_kwargs = ", ".join(f"{key}={repr(value)}" for key, value in index_kwargs.items())
combined_args = ", ".join([a for a in [formatted_args, formatted_kwargs] if a.strip()])

self.message = f"""You are attempting to access the Index client directly from the pinecone module. The Index client must be instantiated through the parent Pinecone client instance so that it can inherit shared configurations such as API key.

INCORRECT USAGE:
```
import pinecone

pc = pinecone.Pinecone(api_key='your-api-key')
index = pinecone.Index({combined_args})
```

CORRECT USAGE:
```
from pinecone import Pinecone

pc = Pinecone(api_key='your-api-key')
index = pc.Index({combined_args})
```
"""
super().__init__(self.message)


class Index:
def __init__(self, *args, **kwargs):
raise IndexClientInstantiationError(args, kwargs)


class _Index(IndexInterface, ImportFeatureMixin):
"""
A client for interacting with a Pinecone index via REST API.
Expand Down Expand Up @@ -172,6 +129,9 @@ def _load_plugins(self):
except Exception as e:
logger.error(f"Error loading plugins in Index: {e}")

def _openapi_kwargs(self, kwargs):
return {k: v for k, v in kwargs.items() if k in OPENAPI_ENDPOINT_PARAMS}

def __enter__(self):
return self

Expand Down Expand Up @@ -221,19 +181,9 @@ def _upsert_batch(
_check_type: bool,
**kwargs,
) -> UpsertResponse:
args_dict = parse_non_empty_args([("namespace", namespace)])

def vec_builder(v):
return VectorFactory.build(v, check_type=_check_type)

return self._vector_api.upsert_vectors(
UpsertRequest(
vectors=list(map(vec_builder, vectors)),
**args_dict,
_check_type=_check_type,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
IndexRequestFactory.upsert_request(vectors, namespace, _check_type, **kwargs),
**self._openapi_kwargs(kwargs),
)

@staticmethod
Expand Down Expand Up @@ -277,22 +227,11 @@ def delete(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
**kwargs,
) -> Dict[str, Any]:
_check_type = kwargs.pop("_check_type", False)
args_dict = parse_non_empty_args(
[("ids", ids), ("delete_all", delete_all), ("namespace", namespace), ("filter", filter)]
)

return self._vector_api.delete_vectors(
DeleteRequest(
**args_dict,
**{
k: v
for k, v in kwargs.items()
if k not in _OPENAPI_ENDPOINT_PARAMS and v is not None
},
_check_type=_check_type,
IndexRequestFactory.delete_request(
ids=ids, delete_all=delete_all, namespace=namespace, filter=filter, **kwargs
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
**self._openapi_kwargs(kwargs),
)

@validate_and_convert_errors
Expand Down Expand Up @@ -354,35 +293,18 @@ def _query(
"The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
)

if vector is not None and id is not None:
raise ValueError("Cannot specify both `id` and `vector`")

_check_type = kwargs.pop("_check_type", False)

sparse_vector = self._parse_sparse_values_arg(sparse_vector)
args_dict = parse_non_empty_args(
[
("vector", vector),
("id", id),
("queries", None),
("top_k", top_k),
("namespace", namespace),
("filter", filter),
("include_values", include_values),
("include_metadata", include_metadata),
("sparse_vector", sparse_vector),
]
)

response = self._vector_api.query_vectors(
QueryRequest(
**args_dict,
_check_type=_check_type,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
request = IndexRequestFactory.query_request(
top_k=top_k,
vector=vector,
id=id,
namespace=namespace,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
**kwargs,
)
return response
return self._vector_api.query_vectors(request, **self._openapi_kwargs(kwargs))

@validate_and_convert_errors
def query_namespaces(
Expand Down Expand Up @@ -445,40 +367,25 @@ def update(
] = None,
**kwargs,
) -> Dict[str, Any]:
_check_type = kwargs.pop("_check_type", False)
sparse_values = self._parse_sparse_values_arg(sparse_values)
args_dict = parse_non_empty_args(
[
("values", values),
("set_metadata", set_metadata),
("namespace", namespace),
("sparse_values", sparse_values),
]
)
return self._vector_api.update_vector(
UpdateRequest(
IndexRequestFactory.update_request(
id=id,
**args_dict,
_check_type=_check_type,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
values=values,
set_metadata=set_metadata,
namespace=namespace,
sparse_values=sparse_values,
**kwargs,
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
**self._openapi_kwargs(kwargs),
)

@validate_and_convert_errors
def describe_index_stats(
self, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs
) -> DescribeIndexStatsResponse:
_check_type = kwargs.pop("_check_type", False)
args_dict = parse_non_empty_args([("filter", filter)])

return self._vector_api.describe_index_stats(
DescribeIndexStatsRequest(
**args_dict,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
_check_type=_check_type,
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
IndexRequestFactory.describe_index_stats_request(filter, **kwargs),
**self._openapi_kwargs(kwargs),
)

@validate_and_convert_errors
Expand All @@ -490,13 +397,12 @@ def list_paginated(
namespace: Optional[str] = None,
**kwargs,
) -> ListResponse:
args_dict = parse_non_empty_args(
[
("prefix", prefix),
("limit", limit),
("namespace", namespace),
("pagination_token", pagination_token),
]
args_dict = IndexRequestFactory.list_paginated_args(
prefix=prefix,
limit=limit,
pagination_token=pagination_token,
namespace=namespace,
**kwargs,
)
return self._vector_api.list_vectors(**args_dict, **kwargs)

Expand All @@ -512,25 +418,3 @@ def list(self, **kwargs):
kwargs.update({"pagination_token": results.pagination.next})
else:
done = True

@staticmethod
def _parse_sparse_values_arg(
sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]],
) -> Optional[SparseValues]:
if sparse_values is None:
return None

if isinstance(sparse_values, SparseValues):
return sparse_values

if (
not isinstance(sparse_values, dict)
or "indices" not in sparse_values
or "values" not in sparse_values
):
raise ValueError(
"Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
f"Received: {sparse_values}"
)

return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
Loading
Loading