Skip to content

Commit

Permalink
[Refactor] Extract index request factory (#420)
Browse files Browse the repository at this point in the history
## Problem

With my asyncio client in progress (not in this diff), I found myself
needing to reuse a lot of the openapi request-building boilerplate.

## Solution

Rather than copy/paste and create a lot of duplication, I want to pull
that logic out into a separate class for building request objects. This
change should not break existing behavior in the Index client.

## Todo

This was pretty much just a mechanical extraction. I'd like to cleanup
and standardize some stuff regarding `_check_type`, an openapi param
that we're going to silly lengths to flip the default behavior on, in a
follow-up diff.

## Type of Change

- [x] None of the above: Refactor

## Test Plan

Should still have tests passing
  • Loading branch information
jhamon authored Dec 3, 2024
1 parent d9f365b commit 79d8032
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 152 deletions.
1 change: 1 addition & 0 deletions pinecone/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .index import *
from .import_error import Index, IndexClientInstantiationError
from .errors import (
VectorDictionaryMissingKeysError,
VectorDictionaryExcessKeysError,
Expand Down
30 changes: 30 additions & 0 deletions pinecone/data/import_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class IndexClientInstantiationError(Exception):
def __init__(self, index_args, index_kwargs):
formatted_args = ", ".join(map(repr, index_args))
formatted_kwargs = ", ".join(f"{key}={repr(value)}" for key, value in index_kwargs.items())
combined_args = ", ".join([a for a in [formatted_args, formatted_kwargs] if a.strip()])

self.message = f"""You are attempting to access the Index client directly from the pinecone module. The Index client must be instantiated through the parent Pinecone client instance so that it can inherit shared configurations such as API key.
INCORRECT USAGE:
```
import pinecone
pc = pinecone.Pinecone(api_key='your-api-key')
index = pinecone.Index({combined_args})
```
CORRECT USAGE:
```
from pinecone import Pinecone
pc = Pinecone(api_key='your-api-key')
index = pc.Index({combined_args})
```
"""
super().__init__(self.message)


class Index:
def __init__(self, *args, **kwargs):
raise IndexClientInstantiationError(args, kwargs)
188 changes: 36 additions & 152 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
SparseValues,
)
from .interfaces import IndexInterface
from .request_factory import IndexRequestFactory
from .features.bulk_import import ImportFeatureMixin
from ..utils import (
setup_openapi_client,
parse_non_empty_args,
build_plugin_setup_client,
validate_and_convert_errors,
)
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS

from multiprocessing.pool import ApplyResult
from concurrent.futures import as_completed
Expand All @@ -45,7 +46,6 @@
logger = logging.getLogger(__name__)

__all__ = [
"Index",
"_Index",
"FetchResponse",
"QueryRequest",
Expand All @@ -64,55 +64,12 @@
"SparseValues",
]

_OPENAPI_ENDPOINT_PARAMS = (
"_return_http_data_only",
"_preload_content",
"_request_timeout",
"_check_input_type",
"_check_return_type",
"_host_index",
"async_req",
"async_threadpool_executor",
)


def parse_query_response(response: QueryResponse):
response._data_store.pop("results", None)
return response


class IndexClientInstantiationError(Exception):
def __init__(self, index_args, index_kwargs):
formatted_args = ", ".join(map(repr, index_args))
formatted_kwargs = ", ".join(f"{key}={repr(value)}" for key, value in index_kwargs.items())
combined_args = ", ".join([a for a in [formatted_args, formatted_kwargs] if a.strip()])

self.message = f"""You are attempting to access the Index client directly from the pinecone module. The Index client must be instantiated through the parent Pinecone client instance so that it can inherit shared configurations such as API key.
INCORRECT USAGE:
```
import pinecone
pc = pinecone.Pinecone(api_key='your-api-key')
index = pinecone.Index({combined_args})
```
CORRECT USAGE:
```
from pinecone import Pinecone
pc = Pinecone(api_key='your-api-key')
index = pc.Index({combined_args})
```
"""
super().__init__(self.message)


class Index:
def __init__(self, *args, **kwargs):
raise IndexClientInstantiationError(args, kwargs)


class _Index(IndexInterface, ImportFeatureMixin):
"""
A client for interacting with a Pinecone index via REST API.
Expand Down Expand Up @@ -172,6 +129,9 @@ def _load_plugins(self):
except Exception as e:
logger.error(f"Error loading plugins in Index: {e}")

def _openapi_kwargs(self, kwargs):
return {k: v for k, v in kwargs.items() if k in OPENAPI_ENDPOINT_PARAMS}

def __enter__(self):
return self

Expand Down Expand Up @@ -221,19 +181,9 @@ def _upsert_batch(
_check_type: bool,
**kwargs,
) -> UpsertResponse:
args_dict = parse_non_empty_args([("namespace", namespace)])

def vec_builder(v):
return VectorFactory.build(v, check_type=_check_type)

return self._vector_api.upsert_vectors(
UpsertRequest(
vectors=list(map(vec_builder, vectors)),
**args_dict,
_check_type=_check_type,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
IndexRequestFactory.upsert_request(vectors, namespace, _check_type, **kwargs),
**self._openapi_kwargs(kwargs),
)

@staticmethod
Expand Down Expand Up @@ -277,22 +227,11 @@ def delete(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
**kwargs,
) -> Dict[str, Any]:
_check_type = kwargs.pop("_check_type", False)
args_dict = parse_non_empty_args(
[("ids", ids), ("delete_all", delete_all), ("namespace", namespace), ("filter", filter)]
)

return self._vector_api.delete_vectors(
DeleteRequest(
**args_dict,
**{
k: v
for k, v in kwargs.items()
if k not in _OPENAPI_ENDPOINT_PARAMS and v is not None
},
_check_type=_check_type,
IndexRequestFactory.delete_request(
ids=ids, delete_all=delete_all, namespace=namespace, filter=filter, **kwargs
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
**self._openapi_kwargs(kwargs),
)

@validate_and_convert_errors
Expand Down Expand Up @@ -354,35 +293,18 @@ def _query(
"The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
)

if vector is not None and id is not None:
raise ValueError("Cannot specify both `id` and `vector`")

_check_type = kwargs.pop("_check_type", False)

sparse_vector = self._parse_sparse_values_arg(sparse_vector)
args_dict = parse_non_empty_args(
[
("vector", vector),
("id", id),
("queries", None),
("top_k", top_k),
("namespace", namespace),
("filter", filter),
("include_values", include_values),
("include_metadata", include_metadata),
("sparse_vector", sparse_vector),
]
)

response = self._vector_api.query_vectors(
QueryRequest(
**args_dict,
_check_type=_check_type,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
request = IndexRequestFactory.query_request(
top_k=top_k,
vector=vector,
id=id,
namespace=namespace,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
**kwargs,
)
return response
return self._vector_api.query_vectors(request, **self._openapi_kwargs(kwargs))

@validate_and_convert_errors
def query_namespaces(
Expand Down Expand Up @@ -445,40 +367,25 @@ def update(
] = None,
**kwargs,
) -> Dict[str, Any]:
_check_type = kwargs.pop("_check_type", False)
sparse_values = self._parse_sparse_values_arg(sparse_values)
args_dict = parse_non_empty_args(
[
("values", values),
("set_metadata", set_metadata),
("namespace", namespace),
("sparse_values", sparse_values),
]
)
return self._vector_api.update_vector(
UpdateRequest(
IndexRequestFactory.update_request(
id=id,
**args_dict,
_check_type=_check_type,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
values=values,
set_metadata=set_metadata,
namespace=namespace,
sparse_values=sparse_values,
**kwargs,
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
**self._openapi_kwargs(kwargs),
)

@validate_and_convert_errors
def describe_index_stats(
self, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs
) -> DescribeIndexStatsResponse:
_check_type = kwargs.pop("_check_type", False)
args_dict = parse_non_empty_args([("filter", filter)])

return self._vector_api.describe_index_stats(
DescribeIndexStatsRequest(
**args_dict,
**{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS},
_check_type=_check_type,
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
IndexRequestFactory.describe_index_stats_request(filter, **kwargs),
**self._openapi_kwargs(kwargs),
)

@validate_and_convert_errors
Expand All @@ -490,13 +397,12 @@ def list_paginated(
namespace: Optional[str] = None,
**kwargs,
) -> ListResponse:
args_dict = parse_non_empty_args(
[
("prefix", prefix),
("limit", limit),
("namespace", namespace),
("pagination_token", pagination_token),
]
args_dict = IndexRequestFactory.list_paginated_args(
prefix=prefix,
limit=limit,
pagination_token=pagination_token,
namespace=namespace,
**kwargs,
)
return self._vector_api.list_vectors(**args_dict, **kwargs)

Expand All @@ -512,25 +418,3 @@ def list(self, **kwargs):
kwargs.update({"pagination_token": results.pagination.next})
else:
done = True

@staticmethod
def _parse_sparse_values_arg(
sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]],
) -> Optional[SparseValues]:
if sparse_values is None:
return None

if isinstance(sparse_values, SparseValues):
return sparse_values

if (
not isinstance(sparse_values, dict)
or "indices" not in sparse_values
or "values" not in sparse_values
):
raise ValueError(
"Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
f"Received: {sparse_values}"
)

return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
Loading

0 comments on commit 79d8032

Please sign in to comment.