query_namespaces performance improvements (#417)
## Problem

We want to improve the performance of the REST implementation of
`query_namespaces`.

## Solution

- Add a `pytest-benchmark` dev dependency and some basic performance tests
to measure the impact of these changes. For now these run only
on my local machine, but they could eventually be
expanded into an automated suite.
- Pass `_preload_content=False` to tell the underlying generated code
not to instantiate response objects for all the intermediate results.
- Use `ThreadPoolExecutor` instead of the older `ThreadPool` implementation
from `multiprocessing`. This involved some changes to the generated code,
but the benefit of this approach is that you get back a
`concurrent.futures.Future` instead of an `ApplyResult`, which is much
more ergonomic (a short comparison sketch appears after the
`api_client.py` diff below). I'm planning to extract the edited files out
of the code-gen process very shortly, so modifying generated files
shouldn't be a concern in this case. I gated this approach behind a
new kwarg, `async_threadpool_executor`, that lives alongside
`async_req`; eventually I would like to replace all usage of
`async_req`'s `ThreadPool` with `ThreadPoolExecutor` to bring the REST and
gRPC implementations closer together, but I can't do that in this PR
without creating a breaking change.

The net effect of these changes is roughly an 18% performance
improvement. A usage sketch of the improved path follows.
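
This is a minimal sketch, not part of the diff; the host and vector dimension are placeholders, and `Pinecone()` is assumed to read `PINECONE_API_KEY` from the environment:

```python
from pinecone import Pinecone

pc = Pinecone()  # assumed to read PINECONE_API_KEY from the environment
# Hypothetical host; pool_threads sizes the thread pool used to fan out requests.
index = pc.Index(host="your-index-host.pinecone.io", pool_threads=4)

# One query fanned out across several namespaces; results are merged into a
# single top_k set by QueryResultsAggregator under the hood.
results = index.query_namespaces(
    vector=[0.1] * 1024,  # dimension must match your index; 1024 is a placeholder
    namespaces=["ns1", "ns2", "ns3"],
    top_k=10,
    include_values=False,
    include_metadata=True,
)
print(results)
```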

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)
jhamon authored Nov 13, 2024
1 parent ab28227 commit 4a99468
Showing 7 changed files with 172 additions and 6 deletions.
36 changes: 36 additions & 0 deletions pinecone/core/openapi/shared/api_client.py
@@ -2,6 +2,7 @@
import atexit
import mimetypes
from multiprocessing.pool import ThreadPool
from concurrent.futures import ThreadPoolExecutor
import io
import os
import re
@@ -70,6 +71,7 @@ class ApiClient(object):
    """

    _pool = None
    _threadpool_executor = None

    def __init__(self, configuration=None, header_name=None, header_value=None, cookie=None, pool_threads=1):
        if configuration is None:
@@ -92,6 +94,9 @@ def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        if self._threadpool_executor:
            self._threadpool_executor.shutdown()
            self._threadpool_executor = None
        if self._pool:
            self._pool.close()
            self._pool.join()
@@ -109,6 +114,12 @@ def pool(self):
            self._pool = ThreadPool(self.pool_threads)
        return self._pool

    @property
    def threadpool_executor(self):
        if self._threadpool_executor is None:
            self._threadpool_executor = ThreadPoolExecutor(max_workers=self.pool_threads)
        return self._threadpool_executor

    @property
    def user_agent(self):
        """User agent for this API client"""
@@ -334,6 +345,7 @@ def call_api(
        response_type: typing.Optional[typing.Tuple[typing.Any]] = None,
        auth_settings: typing.Optional[typing.List[str]] = None,
        async_req: typing.Optional[bool] = None,
        async_threadpool_executor: typing.Optional[bool] = None,
        _return_http_data_only: typing.Optional[bool] = None,
        collection_formats: typing.Optional[typing.Dict[str, str]] = None,
        _preload_content: bool = True,
@@ -394,6 +406,27 @@
        If parameter async_req is False or missing,
        then the method will return the response directly.
        """
        if async_threadpool_executor:
            return self.threadpool_executor.submit(
                self.__call_api,
                resource_path,
                method,
                path_params,
                query_params,
                header_params,
                body,
                post_params,
                files,
                response_type,
                auth_settings,
                _return_http_data_only,
                collection_formats,
                _preload_content,
                _request_timeout,
                _host,
                _check_type,
            )

        if not async_req:
            return self.__call_api(
                resource_path,
@@ -690,6 +723,7 @@ def __init__(self, settings=None, params_map=None, root_map=None, headers_map=None
        self.params_map["all"].extend(
            [
                "async_req",
                "async_threadpool_executor",
                "_host_index",
                "_preload_content",
                "_request_timeout",
@@ -704,6 +738,7 @@ def __init__(self, settings=None, params_map=None, root_map=None, headers_map=None
        self.openapi_types = root_map["openapi_types"]
        extra_types = {
            "async_req": (bool,),
            "async_threadpool_executor": (bool,),
            "_host_index": (none_type, int),
            "_preload_content": (bool,),
            "_request_timeout": (none_type, float, (float,), [float], int, (int,), [int]),
@@ -853,6 +888,7 @@ def call_with_http_info(self, **kwargs):
            response_type=self.settings["response_type"],
            auth_settings=self.settings["auth"],
            async_req=kwargs["async_req"],
            async_threadpool_executor=kwargs.get("async_threadpool_executor", None),
            _check_type=kwargs["_check_return_type"],
            _return_http_data_only=kwargs["_return_http_data_only"],
            _preload_content=kwargs["_preload_content"],
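
To make the ergonomics point from the PR description concrete, here is a small self-contained comparison of the two async styles (illustrative only; `work` stands in for an API call):

```python
from multiprocessing.pool import ThreadPool
from concurrent.futures import ThreadPoolExecutor, as_completed

def work(x: int) -> int:
    return x * x  # stand-in for a network call

# Older async_req path: apply_async returns an ApplyResult, which only
# offers a blocking .get(), typically consumed in submission order.
with ThreadPool(processes=2) as pool:
    results = [pool.apply_async(work, (i,)) for i in range(4)]
    print([r.get() for r in results])

# New async_threadpool_executor path: submit returns a concurrent.futures
# Future, which composes with as_completed, timeouts, and done-callbacks.
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(work, i) for i in range(4)]
    print(sorted(f.result() for f in as_completed(futures)))
```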
20 changes: 20 additions & 0 deletions pinecone/core/openapi/shared/configuration.py
@@ -469,3 +469,23 @@ def host(self, value):
        """Fix base path."""
        self._base_path = value
        self.server_index = None

    def __repr__(self):
        attrs = [
            f"host={self.host}",
            f"api_key=***",
            f"api_key_prefix={self.api_key_prefix}",
            f"access_token={self.access_token}",
            f"connection_pool_maxsize={self.connection_pool_maxsize}",
            f"username={self.username}",
            f"password={self.password}",
            f"discard_unknown_keys={self.discard_unknown_keys}",
            f"disabled_client_side_validations={self.disabled_client_side_validations}",
            f"server_index={self.server_index}",
            f"server_variables={self.server_variables}",
            f"server_operation_index={self.server_operation_index}",
            f"server_operation_variables={self.server_operation_variables}",
            f"ssl_ca_cert={self.ssl_ca_cert}",
        ]
        return f"Configuration({', '.join(attrs)})"
17 changes: 12 additions & 5 deletions pinecone/data/index.py
@@ -1,6 +1,7 @@
from tqdm.autonotebook import tqdm

import logging
import json
from typing import Union, List, Optional, Dict, Any

from pinecone.config import ConfigBuilder
@@ -34,7 +35,9 @@
from .features.bulk_import import ImportFeatureMixin
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults

from multiprocessing.pool import ApplyResult
from concurrent.futures import as_completed

from pinecone_plugin_interface import load_and_install as install_plugins

@@ -67,6 +70,7 @@
    "_check_return_type",
    "_host_index",
    "async_req",
    "async_threadpool_executor",
)


@@ -447,7 +451,7 @@ def query(
            **kwargs,
        )

        if kwargs.get("async_req", False):
        if kwargs.get("async_req", False) or kwargs.get("async_threadpool_executor", False):
            return response
        else:
            return parse_query_response(response)
@@ -491,6 +495,7 @@ def _query(
                ("sparse_vector", sparse_vector),
            ]
        )

        response = self._vector_api.query(
            QueryRequest(
                **args_dict,
@@ -566,7 +571,7 @@ def query_namespaces(
        aggregator = QueryResultsAggregator(top_k=overall_topk)

        target_namespaces = set(namespaces)  # dedup namespaces
        async_results = [
        async_futures = [
            self.query(
                vector=vector,
                namespace=ns,
@@ -575,14 +580,16 @@
                include_values=include_values,
                include_metadata=include_metadata,
                sparse_vector=sparse_vector,
                async_req=True,
                async_threadpool_executor=True,
                _preload_content=False,
                **kwargs,
            )
            for ns in target_namespaces
        ]

        for result in async_results:
            response = result.get()
        for result in as_completed(async_futures):
            raw_result = result.result()
            response = json.loads(raw_result.data.decode("utf-8"))
            aggregator.add_results(response)

        final_results = aggregator.get_results()
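
The loop above is the core of the speedup: per-namespace queries are submitted to the executor and aggregated as they complete, and because `_preload_content=False` skips instantiating OpenAPI response models, the raw urllib3 payload is parsed with `json.loads` directly. A self-contained sketch of the same pattern, with a hypothetical `fetch_namespace` standing in for the REST call:

```python
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch_namespace(ns: str) -> bytes:
    # Hypothetical stand-in for a REST query issued with _preload_content=False,
    # which returns raw response bytes instead of deserialized model objects.
    return json.dumps({"namespace": ns, "matches": []}).encode("utf-8")

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(fetch_namespace, ns) for ns in ["ns1", "ns2"]]
    for future in as_completed(futures):
        response = json.loads(future.result().decode("utf-8"))
        print(response["namespace"])  # fed into the aggregator in the real code
```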
33 changes: 32 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -82,6 +82,9 @@ pytest-asyncio = "0.15.1"
pytest-cov = "2.10.1"
pytest-mock = "3.6.1"
pytest-timeout = "2.2.0"
pytest-benchmark = [
{ version = '5.0.0', python = ">=3.9,<4.0" }
]
urllib3_mock = "0.3.3"
responses = ">=0.8.1"
ddtrace = "^2.14.4"
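
With this dependency in place, the new benchmarks below can be run locally with something like `poetry run pytest tests/perf` (command assumed; as noted above, they target a live index and are not wired into CI).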
45 changes: 45 additions & 0 deletions tests/perf/test_query_namespaces.py
@@ -0,0 +1,45 @@
import time
import random
import pytest
from pinecone import Pinecone
from pinecone.grpc import PineconeGRPC

latencies = []


def call_n_threads(index):
    query_vec = [random.random() for i in range(1024)]
    start = time.time()
    combined_results = index.query_namespaces(
        vector=query_vec,
        namespaces=["ns1", "ns2", "ns3", "ns4"],
        include_values=False,
        include_metadata=True,
        filter={"publication_date": {"$eq": "Last3Months"}},
        top_k=1000,
    )
    finish = time.time()
    # print(f"Query took {finish-start} seconds")
    latencies.append(finish - start)

    return combined_results


class TestQueryNamespacesRest:
    @pytest.mark.parametrize("n_threads", [4])
    def test_query_namespaces_grpc(self, benchmark, n_threads):
        pc = PineconeGRPC()
        index = pc.Index(
            host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io", pool_threads=n_threads
        )
        benchmark.pedantic(call_n_threads, (index,), rounds=10, warmup_rounds=1, iterations=5)

    @pytest.mark.parametrize("n_threads", [4])
    def test_query_namespaces_rest(self, benchmark, n_threads):
        pc = Pinecone()
        index = pc.Index(
            host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io",
            pool_threads=n_threads,
            connection_pool_maxsize=20,
        )
        benchmark.pedantic(call_n_threads, (index,), rounds=10, warmup_rounds=1, iterations=5)
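
For reference, `benchmark.pedantic(..., rounds=10, warmup_rounds=1, iterations=5)` runs one unmeasured warmup round and then ten measured rounds of five calls each.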
24 changes: 24 additions & 0 deletions tests/perf/test_query_results_aggregator.py
@@ -0,0 +1,24 @@
import random
from pinecone.data.query_results_aggregator import QueryResultsAggregator


def fake_results(i):
    matches = [
        {"id": f"id{i}", "score": random.random(), "values": [random.random() for _ in range(768)]}
        for _ in range(1000)
    ]
    matches.sort(key=lambda x: x["score"], reverse=True)
    return {"namespace": f"ns{i}", "matches": matches}


def aggregate_results(responses):
    ag = QueryResultsAggregator(1000)
    for response in responses:
        ag.add_results(response)
    return ag.get_results()


class TestQueryResultsAggregatorPerf:
    def test_my_stuff(self, benchmark):
        responses = [fake_results(i) for i in range(10)]
        benchmark(aggregate_results, responses)
