Skip to content

Commit

Permalink
fix(api): search vectorstore using only last message (#939)
Browse files Browse the repository at this point in the history
* Search VectorStore using only last message
* tests: add assertion for correct RAG results
  • Loading branch information
alekst23 authored Sep 4, 2024
1 parent 985642a commit 8a1d61e
Show file tree
Hide file tree
Showing 16 changed files with 126 additions and 36 deletions.
6 changes: 2 additions & 4 deletions src/leapfrogai_api/backend/rag/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ async def index_file(self, vector_store_id: str, file_id: str) -> VectorStoreFil
crud_vector_store_file = CRUDVectorStoreFile(db=self.db)
crud_vector_store = CRUDVectorStore(db=self.db)

if await crud_vector_store_file.get(
if file_existing := await crud_vector_store_file.get(
filters=FilterVectorStoreFile(vector_store_id=vector_store_id, id=file_id)
):
logger.error("File already indexed: %s", file_id)
raise FileAlreadyIndexedError("File already indexed")

return file_existing
if not (
await crud_vector_store.get(filters=FilterVectorStore(id=vector_store_id))
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,6 @@ def sort_by_created_at(msg: Message):
)
)

first_message: ChatMessage = chat_thread_messages[0]

# Holds the converted thread's messages, this will be built up with a series of push operations
chat_messages: list[ChatMessage] = []

Expand All @@ -250,24 +248,24 @@ def sort_by_created_at(msg: Message):
for message in chat_thread_messages:
chat_messages.append(message)

use_rag: bool = self.can_use_rag(tool_resources)
# 4 - The RAG results are appended behind the user's query
if self.can_use_rag(tool_resources):
rag_message: str = "Here are relevant docs needed to reply:\n"

rag_message: str = "Here are relevant docs needed to reply:\n"
query_message: ChatMessage = chat_thread_messages[-1]

# 4 - The RAG results are appended behind the user's query
file_ids: set[str] = set()
if use_rag:
query_service = QueryService(db=session)
file_search: BetaThreadToolResourcesFileSearch = cast(
BetaThreadToolResourcesFileSearch, tool_resources.file_search
)
vector_store_ids: list[str] = cast(list[str], file_search.vector_store_ids)

vector_store_ids: list[str] = cast(list[str], file_search.vector_store_ids)
file_ids: set[str] = set()
for vector_store_id in vector_store_ids:
rag_results_raw: SingleAPIResponse[
SearchResponse
] = await query_service.query_rag(
query=first_message.content,
query=query_message.content,
vector_store_id=vector_store_id,
)
rag_responses: SearchResponse = SearchResponse(
Expand Down
6 changes: 3 additions & 3 deletions tests/conformance/test_conformance_assistants.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import pytest
from openai.types.beta.assistant import Assistant

from .utils import client_config_factory
from ..utils.client import client_config_factory


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_assistant(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")

assistant = client.beta.assistants.create(
name="Test Assistant",
instructions="You must provide a response based on the attached files.",
model=config["model"],
model=config.model,
tools=[{"type": "file_search"}],
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
Expand Down
4 changes: 2 additions & 2 deletions tests/conformance/test_conformance_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from openai.types.beta.thread import Thread
from openai.types.beta.threads import Message, TextContentBlock, Text

from .utils import client_config_factory
from ..utils.client import client_config_factory


def make_mock_message_object(role, message_text):
Expand Down Expand Up @@ -37,7 +37,7 @@ def make_mock_message_simple(role, message_text):
)
def test_thread(client_name, test_messages):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

thread = client.beta.threads.create(messages=test_messages)

Expand Down
6 changes: 3 additions & 3 deletions tests/conformance/test_conformance_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from openai.types.beta.threads.message import Message
import re

from .utils import client_config_factory, text_file_path
from ..utils.client import client_config_factory, text_file_path


def make_vector_store_with_file(client):
Expand Down Expand Up @@ -46,10 +46,10 @@ def validate_annotation_format(annotation):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_thread_file_annotations(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = make_vector_store_with_file(client)
assistant = make_test_assistant(client, config["model"], vector_store.id)
assistant = make_test_assistant(client, config.model, vector_store.id)
thread = client.beta.threads.create()

client.beta.threads.messages.create(
Expand Down
8 changes: 4 additions & 4 deletions tests/conformance/test_conformance_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from openai.types.beta.vector_store import VectorStore
from openai.types.beta.vector_store_deleted import VectorStoreDeleted

from .utils import client_config_factory
from ..utils.client import client_config_factory


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_create(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = client.beta.vector_stores.create(name="Test data")

Expand All @@ -19,7 +19,7 @@ def test_vector_store_create(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_list(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

client.beta.vector_stores.create(name="Test data")

Expand All @@ -34,7 +34,7 @@ def test_vector_store_list(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_delete(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")

Expand Down
8 changes: 4 additions & 4 deletions tests/conformance/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
)
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile

from .utils import client_config_factory, text_file_path
from ..utils.client import client_config_factory, text_file_path


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_upload(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = client.beta.vector_stores.create(name="Test data")
with open(text_file_path(), "rb") as file:
Expand All @@ -27,7 +27,7 @@ def test_file_upload(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_upload_batches(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = client.beta.vector_stores.create(name="Test data")

Expand All @@ -43,7 +43,7 @@ def test_file_upload_batches(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_delete(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")
with open(text_file_path(), "rb") as file:
Expand Down
6 changes: 3 additions & 3 deletions tests/conformance/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from openai.types.beta.threads.message import Message

from .utils import client_config_factory
from ..utils.client import client_config_factory


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_message_create(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

thread = client.beta.threads.create()
message = client.beta.threads.messages.create(
Expand All @@ -23,7 +23,7 @@ def test_message_create(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_message_list(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

thread = client.beta.threads.create()
client.beta.threads.messages.create(
Expand Down
1 change: 1 addition & 0 deletions tests/data/test_rag_1.1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cats are fascinating creatures known for their agility and independence. Domestic cats, descendants of African wildcats, have been companions to humans for over 4,000 years. They possess excellent night vision, allowing them to see in light levels six times lower than what humans need.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
With approximately 32 muscles in each ear, cats have an exceptional sense of hearing and can rotate their ears 180 degrees. Their unique grooming behavior helps to maintain their coat and regulate body temperature. Additionally, cats have a specialized collarbone that enables them to always land on their feet, a skill they demonstrate with remarkable precision.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cats are remarkable animals with a range of intriguing characteristics. They have a highly developed sense of smell, with about 50 to 80 million scent receptors in their noses, compared to humans who have only around 5 million.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.4.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cats can make over 100 different sounds, including purring, meowing, and hissing, which they use to communicate with their humans and other animals. Their whiskers are highly sensitive and can detect changes in their environment, helping them navigate and judge spaces.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.5.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cats also have a unique grooming behavior, with their tongues covered in tiny, hook-shaped structures called papillae that help them clean their fur and remove loose hair. Additionally, cats sleep for about 12 to 16 hours a day, making them one of the sleepiest of all domestic animals.
1 change: 1 addition & 0 deletions tests/data/test_rag_2.1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
There is only one piece of fruit in the fridge. It is an orange, and located on the bottom shelf, behind the pot of stew.
77 changes: 77 additions & 0 deletions tests/integration/api/test_rag_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
from pathlib import Path
from openai.types.beta.threads.text import Text

from ...utils.client import client_config_factory


def make_test_assistant(client, model, vector_store_id):
    """Create a file-search assistant bound to the given vector store.

    Returns whatever the client's ``beta.assistants.create`` call returns.
    """
    return client.beta.assistants.create(
        name="Test Assistant",
        instructions="You must provide a response based on the attached files.",
        model=model,
        tools=[{"type": "file_search"}],
        tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}},
    )


def make_test_run(client, assistant, thread):
    """Kick off a run of *assistant* on *thread* and poll it to completion."""
    return client.beta.threads.runs.create_and_poll(
        assistant_id=assistant.id, thread_id=thread.id
    )


def test_rag_needle_haystack():
    """Needle-in-a-haystack RAG integration test.

    Uploads five distractor documents about cats plus one document that
    holds the answer (an orange in the fridge), then sends two user
    messages where only the LAST one asks the fridge question. The
    assistant's reply must mention the orange and carry at least one file
    annotation, proving the RAG search used the last message's query.
    """
    config = client_config_factory("leapfrogai")
    client = config.client

    vector_store = client.beta.vector_stores.create(name="Test data")
    # Resolve the fixture directory relative to this test module.
    data_dir = Path(__file__).parent / "../../data"
    file_names = [
        "test_rag_1.1.txt",  # distractors: cat facts
        "test_rag_1.2.txt",
        "test_rag_1.3.txt",
        "test_rag_1.4.txt",
        "test_rag_1.5.txt",
        "test_rag_2.1.txt",  # needle: the orange in the fridge
    ]
    vector_store_files = []
    for file_name in file_names:
        with open(data_dir / file_name, "rb") as file:
            vector_store_files.append(
                client.beta.vector_stores.files.upload(
                    vector_store_id=vector_store.id, file=file
                )
            )

    assistant = make_test_assistant(client, config.model, vector_store.id)
    thread = client.beta.threads.create()

    # First message is a decoy; the RAG query must come from the second.
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content="Tell me about cats.",
    )
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content="There is one piece of fruit in the fridge. What is it and where is it located?",
    )
    run = make_test_run(client, assistant, thread)

    messages = client.beta.threads.messages.list(
        thread_id=thread.id, run_id=run.id
    ).data

    # Get the response content from the last message
    message_content = messages[-1].content[0].text
    assert isinstance(message_content, Text)
    assert "orange" in message_content.value
    assert len(message_content.annotations) > 0
19 changes: 15 additions & 4 deletions tests/conformance/utils.py → tests/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,23 @@ def openai_client():
def leapfrogai_client():
return OpenAI(
base_url=os.getenv("LEAPFROGAI_API_URL"),
api_key=os.getenv("LEAPFROGAI_API_KEY"),
api_key=os.getenv("SUPABASE_USER_JWT"),
)


def client_config_factory(client_name):
class ClientConfig:
    """Bundles an OpenAI-compatible client with the model name to use with it."""

    client: OpenAI
    model: str

    def __init__(self, client: OpenAI, model: str):
        # Store the pair; both attributes are read directly by the tests.
        self.client, self.model = client, model


def client_config_factory(client_name) -> ClientConfig:
if client_name == "openai":
return dict(client=openai_client(), model=OPENAI_MODEL)
return ClientConfig(client=openai_client(), model=OPENAI_MODEL)
elif client_name == "leapfrogai":
return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
return ClientConfig(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
else:
raise ValueError(f"Unknown client name: {client_name}")

0 comments on commit 8a1d61e

Please sign in to comment.