Commit
Feature async elasticsearch (run-llama#7613)
nerdai authored Sep 14, 2023
1 parent 6b545c6 commit 7013811
Showing 7 changed files with 245 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

### New Features
- Simplified portkey LLM interface (#7669)
- Added async operation support to `ElasticsearchStore` vector store (#7613)

### Bug Fixes / Nits
- Avoid `NotImplementedError` for async langchain embeddings (#7668)
3 changes: 1 addition & 2 deletions docs/core_modules/data_modules/storage/vector_stores.md
@@ -15,7 +15,7 @@ We are actively adding more integrations and improving feature coverage for each

| Vector Store | Type | Metadata Filtering | Hybrid Search | Delete | Store Documents | Async |
| ------------------------ | ------------------- | ------------------ | ------------- | ------ | --------------- | ----- |
| Elasticsearch            | self-hosted / cloud | ✓                  | ✓             | ✓      | ✓               |       |
| Elasticsearch            | self-hosted / cloud | ✓                  | ✓             | ✓      | ✓               | ✓     |
| Pinecone                 | cloud               | ✓                  | ✓             | ✓      | ✓               |       |
| Weaviate                 | self-hosted / cloud | ✓                  | ✓             | ✓      | ✓               |       |
| Postgres                 | self-hosted / cloud | ✓                  | ✓             | ✓      | ✓               | ✓     |
@@ -40,7 +40,6 @@ We are actively adding more integrations and improving feature coverage for each
| FAISS                    | in-memory           |                    |               |        |                 |       |
| ChatGPT Retrieval Plugin | aggregator          |                    |               | ✓      | ✓               |       |
| DocArray                 | aggregator          | ✓                  |               | ✓      | ✓               |       |
| Azure Cognitive Search   | cloud               | ✓                  | ✓             | ✓      | ✓               |       |

For more details, see [Vector Store Integrations](/community/integrations/vector_stores.md).
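
The new Async check for Elasticsearch corresponds to coroutine counterparts of the store's methods. A minimal sketch of what that enables (the endpoint, index name, and 1536-dim embedding are placeholder assumptions):

```python
import asyncio

from llama_index.schema import TextNode
from llama_index.vector_stores import ElasticsearchStore


async def main() -> None:
    # Placeholder connection details; adjust for a real cluster.
    store = ElasticsearchStore(
        index_name="llama-index-demo",
        es_url="http://localhost:9200",
    )
    node = TextNode(text="hello async", embedding=[0.1] * 1536)
    # async_add is the coroutine counterpart of the existing add().
    ids = await store.async_add([node])
    print(ids)


asyncio.run(main())
```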

6 changes: 5 additions & 1 deletion docs/examples/vector_stores/ElasticsearchIndexDemo.ipynb
@@ -10,6 +10,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "67837810",
"metadata": {},
@@ -18,7 +19,7 @@
"\n",
"[Signup](https://cloud.elastic.co/registration?utm_source=llama-index&utm_content=documentation) for a free trial.\n",
"\n",
"Requires Elasticsearch 8.9.0 or higher."
"Requires Elasticsearch 8.9.0 or higher and AIOHTTP."
]
},
{
@@ -44,6 +45,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5b28b0ba",
"metadata": {},
@@ -232,6 +234,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "056d6a4a",
"metadata": {},
@@ -329,6 +332,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5b7436ea",
"metadata": {},
20 changes: 19 additions & 1 deletion llama_index/utils.py
@@ -1,5 +1,6 @@
"""General utils functions."""

import asyncio
import os
import random
import sys
@@ -8,7 +9,7 @@
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
@@ -280,6 +281,23 @@ def get_cache_dir() -> str:
return str(path)


def add_sync_version(func: Any) -> Any:
"""Decorator for adding sync version of an async function. The sync version
is added as a function attribute to the original function, func.
Args:
func(Any): the async function for which a sync variant will be built.
"""
assert asyncio.iscoroutinefunction(func)

@wraps(func)
def _wrapper(*args: Any, **kwds: Any) -> Any:
return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))

func.sync = _wrapper
return func


# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
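A short usage sketch for the new decorator (the decorated function is hypothetical). The coroutine is returned unchanged; a blocking variant is attached as its `.sync` attribute:

```python
import asyncio

from llama_index.utils import add_sync_version


@add_sync_version
async def fetch_greeting(name: str) -> str:
    await asyncio.sleep(0)  # stand-in for real async work
    return f"hello, {name}"


# Blocking call, driven by run_until_complete under the hood.
print(fetch_greeting.sync("world"))
# The coroutine itself still works as usual.
print(asyncio.run(fetch_greeting("world")))
```
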
4 changes: 3 additions & 1 deletion llama_index/vector_stores/__init__.py
@@ -12,7 +12,9 @@
DocArrayHnswVectorStore,
DocArrayInMemoryVectorStore,
)
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.elasticsearch import (
ElasticsearchStore,
)
from llama_index.vector_stores.epsilla import EpsillaVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.lancedb import LanceDBVectorStore
151 changes: 117 additions & 34 deletions llama_index/vector_stores/elasticsearch.py
@@ -1,4 +1,6 @@
"""Elasticsearch vector store."""
import asyncio
import nest_asyncio
import uuid
from logging import getLogger
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
@@ -30,7 +32,7 @@ def _get_elasticsearch_client(
username: Optional[str] = None,
password: Optional[str] = None,
) -> Any:
"""Get Elasticsearch client.
"""Get AsyncElasticsearch client.
Args:
es_url: Elasticsearch URL.
Expand All @@ -40,7 +42,7 @@ def _get_elasticsearch_client(
password: Elasticsearch password.
Returns:
Elasticsearch client.
AsyncElasticsearch client.
Raises:
ConnectionError: If Elasticsearch client cannot connect to Elasticsearch.
@@ -78,14 +80,15 @@ def _get_elasticsearch_client(
elif username and password:
connection_params["basic_auth"] = (username, password)

es_client = elasticsearch.Elasticsearch(**connection_params)
sync_es_client = elasticsearch.Elasticsearch(**connection_params)
async_es_client = elasticsearch.AsyncElasticsearch(**connection_params)
try:
es_client.info()
sync_es_client.info()  # so we don't have to 'await' just to get info
except Exception as e:
logger.error(f"Error connecting to Elasticsearch: {e}")
raise e

return es_client
return async_es_client


def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
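
Since the helper now returns an `AsyncElasticsearch` client, a caller can equally inject a pre-built async client via `es_client` instead of URL credentials. A sketch, with placeholder endpoint and auth values:

```python
from elasticsearch import AsyncElasticsearch

from llama_index.vector_stores import ElasticsearchStore

# Placeholder endpoint and credentials.
es_client = AsyncElasticsearch(
    "http://localhost:9200",
    basic_auth=("elastic", "changeme"),
)
store = ElasticsearchStore(index_name="llama-index-demo", es_client=es_client)
```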
@@ -127,7 +130,7 @@ class ElasticsearchStore(VectorStore):
Args:
index_name: Name of the Elasticsearch index.
es_client: Optional. Pre-existing Elasticsearch client.
es_client: Optional. Pre-existing AsyncElasticsearch client.
es_url: Optional. Elasticsearch URL.
es_cloud_id: Optional. Elasticsearch cloud ID.
es_api_key: Optional. Elasticsearch API key.
Expand All @@ -141,7 +144,7 @@ class ElasticsearchStore(VectorStore):
Defaults to "COSINE".
Raises:
ConnectionError: If Elasticsearch client cannot connect to Elasticsearch.
ConnectionError: If AsyncElasticsearch client cannot connect to Elasticsearch.
ValueError: If neither es_client nor es_url nor es_cloud_id is provided.
"""
@@ -162,6 +165,7 @@ def __init__(
batch_size: int = 200,
distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
) -> None:
nest_asyncio.apply()
self.index_name = index_name
self.text_field = text_field
self.vector_field = vector_field
@@ -180,26 +184,26 @@
)
else:
raise ValueError(
"""Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
"""Either provide a pre-existing AsyncElasticsearch or valid \
credentials for creating a new connection."""
)

@property
def client(self) -> Any:
"""Get elasticsearch client."""
"""Get async elasticsearch client"""
return self._client

def _create_index_if_not_exists(
async def _create_index_if_not_exists(
self, index_name: str, dims_length: Optional[int] = None
) -> None:
"""Create the Elasticsearch index if it doesn't already exist.
"""Create the AsyncElasticsearch index if it doesn't already exist.
Args:
index_name: Name of the Elasticsearch index to create.
index_name: Name of the AsyncElasticsearch index to create.
dims_length: Length of the embedding vectors.
"""

if self.client.indices.exists(index=index_name):
if await self.client.indices.exists(index=index_name):
logger.debug(f"Index {index_name} already exists. Skipping creation.")

else:
@@ -244,7 +248,7 @@ def _create_index_if_not_exists(
logger.debug(
f"Creating index {index_name} with mappings {index_settings['mappings']}" # noqa: E501
)
self.client.indices.create(index=index_name, **index_settings)
await self.client.indices.create(index=index_name, **index_settings)

def add(
self,
@@ -264,24 +268,50 @@ def add(
Returns:
List of node IDs that were added to the index.
Raises:
ImportError: If elasticsearch['async'] python package is not installed.
BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
"""
return asyncio.get_event_loop().run_until_complete(
self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists)
)

async def async_add(
self,
nodes: List[BaseNode],
*,
create_index_if_not_exists: bool = True,
) -> List[str]:
"""Asynchronous method to add nodes to Elasticsearch index.
Args:
nodes: List of nodes with embeddings.
create_index_if_not_exists: Optional. Whether to create
the AsyncElasticsearch index if it
doesn't already exist.
Defaults to True.
Returns:
List of node IDs that were added to the index.
Raises:
ImportError: If elasticsearch python package is not installed.
BulkIndexError: If Elasticsearch bulk indexing fails.
BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
"""
try:
from elasticsearch.helpers import BulkIndexError, bulk
from elasticsearch.helpers import BulkIndexError, async_bulk
except ImportError:
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
"Could not import elasticsearch[async] python package. "
"Please install it with `pip install 'elasticsearch[async]'`."
)

if len(nodes) == 0:
return []

if create_index_if_not_exists:
dims_length = len(nodes[0].get_embedding())
self._create_index_if_not_exists(
await self._create_index_if_not_exists(
index_name=self.index_name, dims_length=dims_length
)

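The sync methods now delegate to their async counterparts via `run_until_complete`, which plain asyncio forbids while a loop is already running (e.g. in Jupyter); that is what the `nest_asyncio.apply()` call in `__init__` unlocks. A self-contained sketch of the same delegation pattern:

```python
import asyncio

import nest_asyncio

# Patch the event loop so run_until_complete can be re-entered
# even when a loop is already running (e.g. inside a notebook).
nest_asyncio.apply()


class Store:
    async def async_add(self, items: list) -> list:
        await asyncio.sleep(0)  # stand-in for async_bulk indexing
        return items

    def add(self, items: list) -> list:
        # Same shape as ElasticsearchStore.add: block on the async variant.
        return asyncio.get_event_loop().run_until_complete(self.async_add(items))


print(Store().add(["node-1", "node-2"]))
```
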
@@ -312,9 +342,13 @@ def add(
requests.append(request)
return_ids.append(_id)

bulk(self.client, requests, chunk_size=self.batch_size, refresh=True)
await async_bulk(
self.client, requests, chunk_size=self.batch_size, refresh=True
)
try:
success, failed = bulk(self.client, requests, stats_only=True, refresh=True)
success, failed = await async_bulk(
self.client, requests, stats_only=True, refresh=True
)
logger.debug(f"Added {success} and failed to add {failed} texts to index")

logger.debug(f"added texts {ids} to index")
@@ -336,14 +370,30 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
Raises:
Exception: If Elasticsearch delete_by_query fails.
"""
return asyncio.get_event_loop().run_until_complete(
self.adelete(ref_doc_id, **delete_kwargs)
)

async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""Async delete node from Elasticsearch index.
Args:
ref_doc_id: ID of the node to delete.
delete_kwargs: Optional. Additional arguments to
pass to AsyncElasticsearch delete_by_query.
Raises:
Exception: If AsyncElasticsearch delete_by_query fails.
"""

try:
res = self.client.delete_by_query(
index=self.index_name,
query={"term": {"metadata.ref_doc_id": ref_doc_id}},
refresh=True,
**delete_kwargs,
)
async with self.client as client:
res = await client.delete_by_query(
index=self.index_name,
query={"term": {"metadata.ref_doc_id": ref_doc_id}},
refresh=True,
**delete_kwargs,
)
if res["deleted"] == 0:
logger.warning(f"Could not find text {ref_doc_id} to delete")
else:
@@ -379,6 +429,38 @@ def query(
Raises:
Exception: If Elasticsearch query fails.
"""
return asyncio.get_event_loop().run_until_complete(
self.aquery(query, custom_query, es_filter, **kwargs)
)

async def aquery(
self,
query: VectorStoreQuery,
custom_query: Optional[
Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
] = None,
es_filter: Optional[List[Dict]] = None,
**kwargs: Any,
) -> VectorStoreQueryResult:
"""Asynchronous query index for top k most similar nodes.
Args:
query_embedding (VectorStoreQuery): query embedding
custom_query: Optional. custom query function that takes in the es query
body and returns a modified query body.
This can be used to add additional query
parameters to the AsyncElasticsearch query.
es_filter: Optional. AsyncElasticsearch filter to apply to the
query. If filter is provided in the query,
this filter will be ignored.
Returns:
VectorStoreQueryResult: Result of the query.
Raises:
Exception: If AsyncElasticsearch query fails.
"""
query_embedding = cast(List[float], query.query_embedding)

@@ -419,12 +501,13 @@ def query(
es_query = custom_query(es_query, query)
logger.debug(f"Calling custom_query, Query body now: {es_query}")

response = self.client.search(
index=self.index_name,
**es_query,
size=query.similarity_top_k,
_source={"excludes": [self.vector_field]},
)
async with self.client as client:
response = await client.search(
index=self.index_name,
**es_query,
size=query.similarity_top_k,
_source={"excludes": [self.vector_field]},
)

top_k_nodes = []
top_k_ids = []
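A sketch of the new `aquery` entry point with a metadata filter (endpoint, index name, embedding size, and the filter field are assumptions):

```python
import asyncio

from llama_index.vector_stores import ElasticsearchStore
from llama_index.vector_stores.types import VectorStoreQuery


async def search() -> None:
    store = ElasticsearchStore(
        index_name="llama-index-demo",
        es_url="http://localhost:9200",
    )
    query = VectorStoreQuery(query_embedding=[0.1] * 1536, similarity_top_k=3)
    # es_filter entries are passed through to the Elasticsearch query body;
    # "metadata.author" is a hypothetical metadata field.
    result = await store.aquery(
        query, es_filter=[{"term": {"metadata.author": "nerdai"}}]
    )
    for node_id, score in zip(result.ids, result.similarities):
        print(node_id, score)


asyncio.run(search())
```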