From 70138110658850e15d1e7ecb78cb396fb1661e55 Mon Sep 17 00:00:00 2001 From: nerdai <92402603+nerdai@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:10:09 -0400 Subject: [PATCH] Feature async elasticsearch (#7613) --- CHANGELOG.md | 1 + .../data_modules/storage/vector_stores.md | 3 +- .../ElasticsearchIndexDemo.ipynb | 6 +- llama_index/utils.py | 20 ++- llama_index/vector_stores/__init__.py | 4 +- llama_index/vector_stores/elasticsearch.py | 151 ++++++++++++++---- tests/vector_stores/test_elasticsearch.py | 137 +++++++++++----- 7 files changed, 245 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 931148879b732..450071ad4ca8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features - Simplified portkey LLM interface (#7669) +- Added async operation support to `ElasticsearchStore` vector store (#7613) ### Bug Fixes / Nits - Avoid `NotImplementedError` for async langchain embeddings (#7668) diff --git a/docs/core_modules/data_modules/storage/vector_stores.md b/docs/core_modules/data_modules/storage/vector_stores.md index 19d595ddd7ceb..b21070e4b870d 100644 --- a/docs/core_modules/data_modules/storage/vector_stores.md +++ b/docs/core_modules/data_modules/storage/vector_stores.md @@ -15,7 +15,7 @@ We are actively adding more integrations and improving feature coverage for each | Vector Store | Type | Metadata Filtering | Hybrid Search | Delete | Store Documents | Async | | ------------------------ | ------------------- | ------------------ | ------------- | ------ | --------------- | ----- | -| Elasticsearch | self-hosted / cloud | ✓ | ✓ | ✓ | ✓ | | +| Elasticsearch | self-hosted / cloud | ✓ | ✓ | ✓ | ✓ | ✓ | | Pinecone | cloud | ✓ | ✓ | ✓ | ✓ | | | Weaviate | self-hosted / cloud | ✓ | ✓ | ✓ | ✓ | | | Postgres | self-hosted / cloud | ✓ | ✓ | ✓ | ✓ | ✓ | @@ -40,7 +40,6 @@ We are actively adding more integrations and improving feature coverage for each | FAISS | in-memory | | | | | | | ChatGPT Retrieval Plugin | aggregator | | | ✓ | ✓ | | | DocArray | aggregator | ✓ | | ✓ | ✓ | | -| Azure Cognitive Search | cloud | ✓ | ✓ | ✓ | ✓ | | For more details, see [Vector Store Integrations](/community/integrations/vector_stores.md). diff --git a/docs/examples/vector_stores/ElasticsearchIndexDemo.ipynb b/docs/examples/vector_stores/ElasticsearchIndexDemo.ipynb index a68d618748ebb..e4577f509efa4 100644 --- a/docs/examples/vector_stores/ElasticsearchIndexDemo.ipynb +++ b/docs/examples/vector_stores/ElasticsearchIndexDemo.ipynb @@ -10,6 +10,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "67837810", "metadata": {}, @@ -18,7 +19,7 @@ "\n", "[Signup](https://cloud.elastic.co/registration?utm_source=llama-index&utm_content=documentation) for a free trial.\n", "\n", - "Requires Elasticsearch 8.9.0 or higher." + "Requires Elasticsearch 8.9.0 or higher and AIOHTTP." ] }, { @@ -44,6 +45,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5b28b0ba", "metadata": {}, @@ -232,6 +234,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "056d6a4a", "metadata": {}, @@ -329,6 +332,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5b7436ea", "metadata": {}, diff --git a/llama_index/utils.py b/llama_index/utils.py index f1e0e720ed44c..52059d3813b5d 100644 --- a/llama_index/utils.py +++ b/llama_index/utils.py @@ -1,5 +1,6 @@ """General utils functions.""" +import asyncio import os import random import sys @@ -8,7 +9,7 @@ import uuid from contextlib import contextmanager from dataclasses import dataclass -from functools import partial +from functools import partial, wraps from itertools import islice from pathlib import Path from typing import ( @@ -280,6 +281,23 @@ def get_cache_dir() -> str: return str(path) +def add_sync_version(func: Any) -> Any: + """Decorator for adding sync version of an async function. The sync version + is added as a function attribute to the original function, func. + + Args: + func(Any): the async function for which a sync variant will be built. + """ + assert asyncio.iscoroutinefunction(func) + + @wraps(func) + def _wrapper(*args: Any, **kwds: Any) -> Any: + return asyncio.get_event_loop().run_until_complete(func(*args, **kwds)) + + func.sync = _wrapper + return func + + # Sample text from llama_index's readme SAMPLE_TEXT = """ Context diff --git a/llama_index/vector_stores/__init__.py b/llama_index/vector_stores/__init__.py index 2a169923450ed..24b359672b20f 100644 --- a/llama_index/vector_stores/__init__.py +++ b/llama_index/vector_stores/__init__.py @@ -12,7 +12,9 @@ DocArrayHnswVectorStore, DocArrayInMemoryVectorStore, ) -from llama_index.vector_stores.elasticsearch import ElasticsearchStore +from llama_index.vector_stores.elasticsearch import ( + ElasticsearchStore, +) from llama_index.vector_stores.epsilla import EpsillaVectorStore from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.lancedb import LanceDBVectorStore diff --git a/llama_index/vector_stores/elasticsearch.py b/llama_index/vector_stores/elasticsearch.py index 3c39ed60f472e..7133792984b7a 100644 --- a/llama_index/vector_stores/elasticsearch.py +++ b/llama_index/vector_stores/elasticsearch.py @@ -1,4 +1,6 @@ """Elasticsearch vector store.""" +import asyncio +import nest_asyncio import uuid from logging import getLogger from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast @@ -30,7 +32,7 @@ def _get_elasticsearch_client( username: Optional[str] = None, password: Optional[str] = None, ) -> Any: - """Get Elasticsearch client. + """Get AsyncElasticsearch client. Args: es_url: Elasticsearch URL. @@ -40,7 +42,7 @@ def _get_elasticsearch_client( password: Elasticsearch password. Returns: - Elasticsearch client. + AsyncElasticsearch client. Raises: ConnectionError: If Elasticsearch client cannot connect to Elasticsearch. @@ -78,14 +80,15 @@ def _get_elasticsearch_client( elif username and password: connection_params["basic_auth"] = (username, password) - es_client = elasticsearch.Elasticsearch(**connection_params) + sync_es_client = elasticsearch.Elasticsearch(**connection_params) + async_es_client = elasticsearch.AsyncElasticsearch(**connection_params) try: - es_client.info() + sync_es_client.info() # so don't have to 'await' to just get info except Exception as e: logger.error(f"Error connecting to Elasticsearch: {e}") raise e - return es_client + return async_es_client def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]: @@ -127,7 +130,7 @@ class ElasticsearchStore(VectorStore): Args: index_name: Name of the Elasticsearch index. - es_client: Optional. Pre-existing Elasticsearch client. + es_client: Optional. Pre-existing AsyncElasticsearch client. es_url: Optional. Elasticsearch URL. es_cloud_id: Optional. Elasticsearch cloud ID. es_api_key: Optional. Elasticsearch API key. @@ -141,7 +144,7 @@ class ElasticsearchStore(VectorStore): Defaults to "COSINE". Raises: - ConnectionError: If Elasticsearch client cannot connect to Elasticsearch. + ConnectionError: If AsyncElasticsearch client cannot connect to Elasticsearch. ValueError: If neither es_client nor es_url nor es_cloud_id is provided. """ @@ -162,6 +165,7 @@ def __init__( batch_size: int = 200, distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE", ) -> None: + nest_asyncio.apply() self.index_name = index_name self.text_field = text_field self.vector_field = vector_field @@ -180,26 +184,26 @@ def __init__( ) else: raise ValueError( - """Either provide a pre-existing Elasticsearch connection, \ - or valid credentials for creating a new connection.""" + """Either provide a pre-existing AsyncElasticsearch or valid \ + credentials for creating a new connection.""" ) @property def client(self) -> Any: - """Get elasticsearch client.""" + """Get async elasticsearch client""" return self._client - def _create_index_if_not_exists( + async def _create_index_if_not_exists( self, index_name: str, dims_length: Optional[int] = None ) -> None: - """Create the Elasticsearch index if it doesn't already exist. + """Create the AsyncElasticsearch index if it doesn't already exist. Args: - index_name: Name of the Elasticsearch index to create. + index_name: Name of the AsyncElasticsearch index to create. dims_length: Length of the embedding vectors. """ - if self.client.indices.exists(index=index_name): + if await self.client.indices.exists(index=index_name): logger.debug(f"Index {index_name} already exists. Skipping creation.") else: @@ -244,7 +248,7 @@ def _create_index_if_not_exists( logger.debug( f"Creating index {index_name} with mappings {index_settings['mappings']}" # noqa: E501 ) - self.client.indices.create(index=index_name, **index_settings) + await self.client.indices.create(index=index_name, **index_settings) def add( self, @@ -264,16 +268,42 @@ def add( Returns: List of node IDs that were added to the index. + Raises: + ImportError: If elasticsearch['async'] python package is not installed. + BulkIndexError: If AsyncElasticsearch async_bulk indexing fails. + """ + return asyncio.get_event_loop().run_until_complete( + self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists) + ) + + async def async_add( + self, + nodes: List[BaseNode], + *, + create_index_if_not_exists: bool = True, + ) -> List[str]: + """Asynchronous method to add nodes to Elasticsearch index. + + Args: + nodes: List of nodes with embeddings. + create_index_if_not_exists: Optional. Whether to create + the AsyncElasticsearch index if it + doesn't already exist. + Defaults to True. + + Returns: + List of node IDs that were added to the index. + Raises: ImportError: If elasticsearch python package is not installed. - BulkIndexError: If Elasticsearch bulk indexing fails. + BulkIndexError: If AsyncElasticsearch async_bulk indexing fails. """ try: - from elasticsearch.helpers import BulkIndexError, bulk + from elasticsearch.helpers import BulkIndexError, async_bulk except ImportError: raise ImportError( - "Could not import elasticsearch python package. " - "Please install it with `pip install elasticsearch`." + "Could not import elasticsearch[async] python package. " + "Please install it with `pip install 'elasticsearch[async]'`." ) if len(nodes) == 0: @@ -281,7 +311,7 @@ def add( if create_index_if_not_exists: dims_length = len(nodes[0].get_embedding()) - self._create_index_if_not_exists( + await self._create_index_if_not_exists( index_name=self.index_name, dims_length=dims_length ) @@ -312,9 +342,13 @@ def add( requests.append(request) return_ids.append(_id) - bulk(self.client, requests, chunk_size=self.batch_size, refresh=True) + await async_bulk( + self.client, requests, chunk_size=self.batch_size, refresh=True + ) try: - success, failed = bulk(self.client, requests, stats_only=True, refresh=True) + success, failed = await async_bulk( + self.client, requests, stats_only=True, refresh=True + ) logger.debug(f"Added {success} and failed to add {failed} texts to index") logger.debug(f"added texts {ids} to index") @@ -336,14 +370,30 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: Raises: Exception: If Elasticsearch delete_by_query fails. """ + return asyncio.get_event_loop().run_until_complete( + self.adelete(ref_doc_id, **delete_kwargs) + ) + + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Async delete node from Elasticsearch index. + + Args: + ref_doc_id: ID of the node to delete. + delete_kwargs: Optional. Additional arguments to + pass to AsyncElasticsearch delete_by_query. + + Raises: + Exception: If AsyncElasticsearch delete_by_query fails. + """ try: - res = self.client.delete_by_query( - index=self.index_name, - query={"term": {"metadata.ref_doc_id": ref_doc_id}}, - refresh=True, - **delete_kwargs, - ) + async with self.client as client: + res = await client.delete_by_query( + index=self.index_name, + query={"term": {"metadata.ref_doc_id": ref_doc_id}}, + refresh=True, + **delete_kwargs, + ) if res["deleted"] == 0: logger.warning(f"Could not find text {ref_doc_id} to delete") else: @@ -379,6 +429,38 @@ def query( Raises: Exception: If Elasticsearch query fails. + """ + return asyncio.get_event_loop().run_until_complete( + self.aquery(query, custom_query, es_filter, **kwargs) + ) + + async def aquery( + self, + query: VectorStoreQuery, + custom_query: Optional[ + Callable[[Dict, Union[VectorStoreQuery, None]], Dict] + ] = None, + es_filter: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> VectorStoreQueryResult: + """Asynchronous query index for top k most similar nodes. + + Args: + query_embedding (VectorStoreQuery): query embedding + custom_query: Optional. custom query function that takes in the es query + body and returns a modified query body. + This can be used to add additional query + parameters to the AsyncElasticsearch query. + es_filter: Optional. AsyncElasticsearch filter to apply to the + query. If filter is provided in the query, + this filter will be ignored. + + Returns: + VectorStoreQueryResult: Result of the query. + + Raises: + Exception: If AsyncElasticsearch query fails. + """ query_embedding = cast(List[float], query.query_embedding) @@ -419,12 +501,13 @@ def query( es_query = custom_query(es_query, query) logger.debug(f"Calling custom_query, Query body now: {es_query}") - response = self.client.search( - index=self.index_name, - **es_query, - size=query.similarity_top_k, - _source={"excludes": [self.vector_field]}, - ) + async with self.client as client: + response = await client.search( + index=self.index_name, + **es_query, + size=query.similarity_top_k, + _source={"excludes": [self.vector_field]}, + ) top_k_nodes = [] top_k_ids = [] diff --git a/tests/vector_stores/test_elasticsearch.py b/tests/vector_stores/test_elasticsearch.py index 9649e858200dd..7425c527fae3e 100644 --- a/tests/vector_stores/test_elasticsearch.py +++ b/tests/vector_stores/test_elasticsearch.py @@ -127,20 +127,29 @@ def test_instance_creation(index_name: str, elasticsearch_connection: Dict) -> N @pytest.mark.skipif( elasticsearch_not_available, reason="elasticsearch is not available" ) -def test_add_to_es_and_query( +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_query( index_name: str, elasticsearch_connection: Dict, node_embeddings: List[TextNode], + use_async: bool, ) -> None: es_store = ElasticsearchStore( **elasticsearch_connection, index_name=index_name, distance_strategy="COSINE", ) - es_store.add(node_embeddings) - res = es_store.query( - VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) - ) + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) + ) assert res.nodes assert res.nodes[0].get_content() == "lorem ipsum" @@ -148,22 +157,37 @@ def test_add_to_es_and_query( @pytest.mark.skipif( elasticsearch_not_available, reason="elasticsearch is not available" ) -def test_add_to_es_and_text_query( +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_text_query( index_name: str, elasticsearch_connection: Dict, node_embeddings: List[TextNode], + use_async: bool, ) -> None: es_store = ElasticsearchStore( **elasticsearch_connection, index_name=index_name, distance_strategy="COSINE", ) - es_store.add(node_embeddings) - res = es_store.query( - VectorStoreQuery( - query_str="lorem", mode=VectorStoreQueryMode.TEXT_SEARCH, similarity_top_k=1 + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + VectorStoreQuery( + query_str="lorem", + mode=VectorStoreQueryMode.TEXT_SEARCH, + similarity_top_k=1, + ) + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + VectorStoreQuery( + query_str="lorem", + mode=VectorStoreQueryMode.TEXT_SEARCH, + similarity_top_k=1, + ) ) - ) assert res.nodes assert res.nodes[0].get_content() == "lorem ipsum" @@ -171,25 +195,39 @@ def test_add_to_es_and_text_query( @pytest.mark.skipif( elasticsearch_not_available, reason="elasticsearch is not available" ) -def test_add_to_es_and_hybrid_query( +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_and_hybrid_query( index_name: str, elasticsearch_connection: Dict, node_embeddings: List[TextNode], + use_async: bool, ) -> None: es_store = ElasticsearchStore( **elasticsearch_connection, index_name=index_name, distance_strategy="COSINE", ) - es_store.add(node_embeddings) - res = es_store.query( - VectorStoreQuery( - query_str="lorem", - query_embedding=[1.0, 0.0, 0.0], - mode=VectorStoreQueryMode.HYBRID, - similarity_top_k=1, + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + VectorStoreQuery( + query_str="lorem", + query_embedding=[1.0, 0.0, 0.0], + mode=VectorStoreQueryMode.HYBRID, + similarity_top_k=1, + ) + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + VectorStoreQuery( + query_str="lorem", + query_embedding=[1.0, 0.0, 0.0], + mode=VectorStoreQueryMode.HYBRID, + similarity_top_k=1, + ) ) - ) assert res.nodes assert res.nodes[0].get_content() == "lorem ipsum" @@ -197,27 +235,31 @@ def test_add_to_es_and_hybrid_query( @pytest.mark.skipif( elasticsearch_not_available, reason="elasticsearch is not available" ) -def test_add_to_es_query_with_filters( +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_query_with_filters( index_name: str, elasticsearch_connection: Dict, node_embeddings: List[TextNode], + use_async: bool, ) -> None: es_store = ElasticsearchStore( **elasticsearch_connection, index_name=index_name, distance_strategy="COSINE", ) - - es_store.add(node_embeddings) - filters = MetadataFilters( filters=[ExactMatchFilter(key="author", value="Stephen King")] ) q = VectorStoreQuery( query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10, filters=filters ) - - res = es_store.query(q) + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery(q) + else: + es_store.add(node_embeddings) + res = es_store.query(q) assert res.nodes assert len(res.nodes) == 1 assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0" @@ -226,22 +268,30 @@ def test_add_to_es_query_with_filters( @pytest.mark.skipif( elasticsearch_not_available, reason="elasticsearch is not available" ) -def test_add_to_es_query_with_es_filters( +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_query_with_es_filters( index_name: str, elasticsearch_connection: Dict, node_embeddings: List[TextNode], + use_async: bool, ) -> None: es_store = ElasticsearchStore( **elasticsearch_connection, index_name=index_name, distance_strategy="COSINE", ) - - es_store.add(node_embeddings) - q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10) - - res = es_store.query(q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}]) + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery( + q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}] + ) + else: + es_store.add(node_embeddings) + res = es_store.query( + q, es_filter=[{"wildcard": {"metadata.author": "stephe*"}}] + ) assert res.nodes assert len(res.nodes) == 1 assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0" @@ -250,26 +300,37 @@ def test_add_to_es_query_with_es_filters( @pytest.mark.skipif( elasticsearch_not_available, reason="elasticsearch is not available" ) -def test_add_to_es_query_and_delete( +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_es_query_and_delete( index_name: str, elasticsearch_connection: Dict, node_embeddings: List[TextNode], + use_async: bool, ) -> None: es_store = ElasticsearchStore( **elasticsearch_connection, index_name=index_name, distance_strategy="COSINE", ) - q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1) - es_store.add(node_embeddings) - res = es_store.query(q) + if use_async: + await es_store.async_add(node_embeddings) + res = await es_store.aquery(q) + else: + es_store.add(node_embeddings) + res = es_store.query(q) assert res.nodes assert len(res.nodes) == 1 assert res.nodes[0].node_id == "c330d77f-90bd-4c51-9ed2-57d8d693b3b0" - es_store.delete("test-0") - res = es_store.query(q) + + if use_async: + await es_store.adelete("test-0") + res = await es_store.aquery(q) + else: + es_store.delete("test-0") + res = es_store.query(q) assert res.nodes assert len(res.nodes) == 1 assert res.nodes[0].node_id == "c3d1e1dd-8fb4-4b8f-b7ea-7fa96038d39d"