Skip to content

Commit

Permalink
Update ElasticsearchStore to return normalized similarities (run-llam…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenwebel authored Sep 25, 2023
1 parent 33b49de commit c8f5d84
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 45 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

## Unreleased

### New Features
- Added `Konko` LLM support (#7775)

### Bug Fixes / Nits
- Normalize scores returned from Elasticsearch vector store (#7792)
- Fixed `refresh_ref_docs()` bug with order of operations (#7664)

## [0.8.33] - 2023-09-25
Expand Down
11 changes: 10 additions & 1 deletion llama_index/vector_stores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from logging import getLogger
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

import numpy as np

from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
MetadataFilters,
Expand Down Expand Up @@ -124,6 +126,11 @@ def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any
return {"bool": {"must": operands}}


def _to_llama_similarities(scores: List[float]) -> List[float]:
scores_to_norm: np.ndarray = np.array(scores)
return np.exp(scores_to_norm - np.max(scores_to_norm)).tolist()


class ElasticsearchStore(VectorStore):
"""Elasticsearch vector store.
Expand Down Expand Up @@ -546,5 +553,7 @@ async def aquery(
top_k_ids.append(node_id)
top_k_scores.append(hit["_score"])
return VectorStoreQueryResult(
nodes=top_k_nodes, ids=top_k_ids, similarities=top_k_scores
nodes=top_k_nodes,
ids=top_k_ids,
similarities=_to_llama_similarities(top_k_scores),
)
179 changes: 135 additions & 44 deletions tests/vector_stores/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from typing import Any, Dict, Generator, List, Union

import pandas as pd
import pytest

from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode
Expand All @@ -25,8 +26,12 @@
es_client.info()

elasticsearch_not_available = False

es_license = es_client.license.get()
basic_license: bool = "basic" == es_license["license"]["type"]
except (ImportError, Exception):
elasticsearch_not_available = True
basic_license = True


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -110,6 +115,39 @@ def node_embeddings() -> List[TextNode]:
},
embedding=[0.0, 0.0, 1.0],
),
TextNode(
text="I was taught that the way of progress was neither swift nor easy.",
id_="0b31ae71-b797-4e88-8495-031371a7752e",
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-3")},
metadate={
"author": "Marie Curie",
},
embedding=[0.0, 0.0, 0.9],
),
TextNode(
text=(
"The important thing is not to stop questioning."
+ " Curiosity has its own reason for existing."
),
id_="bd2e080b-159a-4030-acc3-d98afd2ba49b",
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-4")},
metadate={
"author": "Albert Einstein",
},
embedding=[0.0, 0.0, 0.5],
),
TextNode(
text=(
"I am no bird; and no net ensnares me;"
+ " I am a free human being with an independent will."
),
id_="f658de3b-8cef-4d1c-8bed-9a263c907251",
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="text-5")},
metadate={
"author": "Charlotte Bronte",
},
embedding=[0.0, 0.0, 0.3],
),
]


Expand All @@ -124,22 +162,25 @@ def test_instance_creation(index_name: str, elasticsearch_connection: Dict) -> N
assert isinstance(es_store, ElasticsearchStore)


@pytest.fixture(scope="function")
def es_store(index_name: str, elasticsearch_connection: Dict) -> ElasticsearchStore:
    """Provide a fresh ``ElasticsearchStore`` for each test function.

    The store is built from the shared connection settings, targets the
    given index, and uses the ``EUCLIDEAN_DISTANCE`` strategy.
    """
    store = ElasticsearchStore(
        index_name=index_name,
        distance_strategy="EUCLIDEAN_DISTANCE",
        **elasticsearch_connection,
    )
    return store


@pytest.mark.skipif(
elasticsearch_not_available, reason="elasticsearch is not available"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_and_query(
index_name: str,
elasticsearch_connection: Dict,
es_store: ElasticsearchStore,
node_embeddings: List[TextNode],
use_async: bool,
) -> None:
es_store = ElasticsearchStore(
**elasticsearch_connection,
index_name=index_name,
distance_strategy="COSINE",
)
if use_async:
await es_store.async_add(node_embeddings)
res = await es_store.aquery(
Expand All @@ -160,16 +201,10 @@ async def test_add_to_es_and_query(
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_and_text_query(
index_name: str,
elasticsearch_connection: Dict,
es_store: ElasticsearchStore,
node_embeddings: List[TextNode],
use_async: bool,
) -> None:
es_store = ElasticsearchStore(
**elasticsearch_connection,
index_name=index_name,
distance_strategy="COSINE",
)
if use_async:
await es_store.async_add(node_embeddings)
res = await es_store.aquery(
Expand All @@ -193,21 +228,17 @@ async def test_add_to_es_and_text_query(


@pytest.mark.skipif(
elasticsearch_not_available, reason="elasticsearch is not available"
elasticsearch_not_available,
basic_license,
reason="elasticsearch is not available or license is basic",
)
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_and_hybrid_query(
index_name: str,
elasticsearch_connection: Dict,
es_store: ElasticsearchStore,
node_embeddings: List[TextNode],
use_async: bool,
) -> None:
es_store = ElasticsearchStore(
**elasticsearch_connection,
index_name=index_name,
distance_strategy="COSINE",
)
if use_async:
await es_store.async_add(node_embeddings)
res = await es_store.aquery(
Expand Down Expand Up @@ -238,16 +269,10 @@ async def test_add_to_es_and_hybrid_query(
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_query_with_filters(
index_name: str,
elasticsearch_connection: Dict,
es_store: ElasticsearchStore,
node_embeddings: List[TextNode],
use_async: bool,
) -> None:
es_store = ElasticsearchStore(
**elasticsearch_connection,
index_name=index_name,
distance_strategy="COSINE",
)
filters = MetadataFilters(
filters=[ExactMatchFilter(key="author", value="Stephen King")]
)
Expand All @@ -271,16 +296,10 @@ async def test_add_to_es_query_with_filters(
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_query_with_es_filters(
index_name: str,
elasticsearch_connection: Dict,
es_store: ElasticsearchStore,
node_embeddings: List[TextNode],
use_async: bool,
) -> None:
es_store = ElasticsearchStore(
**elasticsearch_connection,
index_name=index_name,
distance_strategy="COSINE",
)
q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=10)
if use_async:
await es_store.async_add(node_embeddings)
Expand All @@ -303,16 +322,10 @@ async def test_add_to_es_query_with_es_filters(
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_query_and_delete(
index_name: str,
elasticsearch_connection: Dict,
es_store: ElasticsearchStore,
node_embeddings: List[TextNode],
use_async: bool,
) -> None:
es_store = ElasticsearchStore(
**elasticsearch_connection,
index_name=index_name,
distance_strategy="COSINE",
)
q = VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1)

if use_async:
Expand All @@ -333,4 +346,82 @@ async def test_add_to_es_query_and_delete(
res = es_store.query(q)
assert res.nodes
assert len(res.nodes) == 1
assert res.nodes[0].node_id == "c3d1e1dd-8fb4-4b8f-b7ea-7fa96038d39d"
assert res.nodes[0].node_id == "f658de3b-8cef-4d1c-8bed-9a263c907251"


@pytest.mark.skipif(
    elasticsearch_not_available, reason="elasticsearch is not available"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_and_embed_query_ranked(
    es_store: ElasticsearchStore,
    node_embeddings: List[TextNode],
    use_async: bool,
) -> None:
    """Run an embedding query and check the ranking of the top three hits.

    The query embedding is ``[0.0, 0.0, 0.5]``; the expected ids, best match
    first, belong to the Einstein, Bronte and Curie nodes of the fixture.
    """
    einstein_id = "bd2e080b-159a-4030-acc3-d98afd2ba49b"
    bronte_id = "f658de3b-8cef-4d1c-8bed-9a263c907251"
    curie_id = "0b31ae71-b797-4e88-8495-031371a7752e"

    embedding_query = VectorStoreQuery(
        similarity_top_k=3,
        query_embedding=[0.0, 0.0, 0.5],
    )
    await check_top_match(
        es_store,
        node_embeddings,
        use_async,
        embedding_query,
        einstein_id,
        bronte_id,
        curie_id,
    )


@pytest.mark.skipif(
    elasticsearch_not_available, reason="elasticsearch is not available"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_es_and_text_query_ranked(
    es_store: ElasticsearchStore,
    node_embeddings: List[TextNode],
    use_async: bool,
) -> None:
    """Run two text searches and check that each ranks its two hits as expected.

    "I was" must rank the Curie node ahead of the Bronte node, while
    "I am" must rank them the other way around.
    """
    curie_id = "0b31ae71-b797-4e88-8495-031371a7752e"
    bronte_id = "f658de3b-8cef-4d1c-8bed-9a263c907251"

    # (search text, expected node ids with the best match first)
    ranked_cases = (
        ("I was", (curie_id, bronte_id)),
        ("I am", (bronte_id, curie_id)),
    )
    for search_text, expected_order in ranked_cases:
        text_query = VectorStoreQuery(
            query_str=search_text,
            mode=VectorStoreQueryMode.TEXT_SEARCH,
            similarity_top_k=2,
        )
        await check_top_match(
            es_store, node_embeddings, use_async, text_query, *expected_order
        )


async def check_top_match(
    es_store: ElasticsearchStore,
    node_embeddings: List[TextNode],
    use_async: bool,
    query: VectorStoreQuery,
    *expected_nodes: str,
) -> None:
    """Index the nodes, run the query, and verify ranking and score bounds.

    Args:
        es_store: Store under test; nodes are added to it before querying.
        node_embeddings: Nodes to index.
        use_async: When true, use ``async_add``/``aquery``; otherwise use
            the synchronous ``add``/``query``.
        query: Query to execute against the store.
        *expected_nodes: Node ids expected at the head of the result, best
            match first.

    Raises:
        AssertionError: If the leading results differ from
            ``expected_nodes``, the ids do not line up with the nodes, the
            similarities are not in descending order, or a similarity falls
            outside ``[0, 1]``.
    """
    if use_async:
        await es_store.async_add(node_embeddings)
        res = await es_store.aquery(query)
    else:
        es_store.add(node_embeddings)
        res = es_store.query(query)
    assert res.nodes

    # Fail with a readable assertion (rather than an IndexError below) when
    # fewer nodes come back than the caller expects.
    assert len(res.nodes) >= len(expected_nodes)
    # test the nodes are returned in the expected order
    for position, expected_id in enumerate(expected_nodes):
        assert res.nodes[position].node_id == expected_id

    # nodes, ids and similarities are parallel sequences
    assert len(res.nodes) == len(res.ids) == len(res.similarities)
    for node, node_id, sim in zip(res.nodes, res.ids, res.similarities):
        assert node.node_id == node_id
        # test similarities are normalized to [0, 1]
        assert 0 <= sim <= 1

    # test the returned order is descending w.r.t. similarities. Comparing
    # neighbours directly (instead of re-sorting) keeps the check
    # deterministic when two hits share the same similarity.
    for higher, lower in zip(res.similarities, res.similarities[1:]):
        assert higher >= lower

0 comments on commit c8f5d84

Please sign in to comment.