fix(api): search vectorstore using only last message #939

Merged
6 changes: 2 additions & 4 deletions src/leapfrogai_api/backend/rag/index.py
@@ -55,12 +55,10 @@ async def index_file(self, vector_store_id: str, file_id: str) -> VectorStoreFile:
         crud_vector_store_file = CRUDVectorStoreFile(db=self.db)
         crud_vector_store = CRUDVectorStore(db=self.db)

-        if await crud_vector_store_file.get(
+        if file_existing := await crud_vector_store_file.get(
             filters=FilterVectorStoreFile(vector_store_id=vector_store_id, id=file_id)
         ):
-            logger.error("File already indexed: %s", file_id)
-            raise FileAlreadyIndexedError("File already indexed")
-
+            return file_existing
         if not (
             await crud_vector_store.get(filters=FilterVectorStore(id=vector_store_id))
         ):
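For illustration, a minimal standalone sketch (not the LeapfrogAI CRUD layer) of the pattern this hunk introduces: indexing a file that is already present now returns the existing record instead of raising FileAlreadyIndexedError, so repeated indexing requests are idempotent. The in-memory store and IndexedFile dataclass below are stand-ins for the real database-backed classes.

import asyncio
from dataclasses import dataclass


@dataclass
class IndexedFile:  # stand-in for the real VectorStoreFile record
    vector_store_id: str
    file_id: str


_store: dict[tuple[str, str], IndexedFile] = {}


async def index_file(vector_store_id: str, file_id: str) -> IndexedFile:
    # Walrus assignment binds the lookup result and tests it in one expression,
    # mirroring the `if file_existing := await crud_vector_store_file.get(...)` change.
    if existing := _store.get((vector_store_id, file_id)):
        return existing  # a duplicate indexing request is now a no-op, not an error
    record = IndexedFile(vector_store_id, file_id)
    _store[(vector_store_id, file_id)] = record
    return record


async def main() -> None:
    first = await index_file("vs_123", "file_abc")
    second = await index_file("vs_123", "file_abc")
    assert first is second  # the second call returned the existing record


asyncio.run(main())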
@@ -231,8 +231,6 @@ def sort_by_created_at(msg: Message):
             )
         )

-        first_message: ChatMessage = chat_thread_messages[0]
-
         # Holds the converted thread's messages, this will be built up with a series of push operations
         chat_messages: list[ChatMessage] = []

@@ -250,24 +248,24 @@
         for message in chat_thread_messages:
             chat_messages.append(message)

-        use_rag: bool = self.can_use_rag(tool_resources)
-
-        rag_message: str = "Here are relevant docs needed to reply:\n"
-
-        # 4 - The RAG results are appended behind the user's query
-        file_ids: set[str] = set()
-        if use_rag:
+        # 4 - The RAG results are appended behind the user's query
+        if self.can_use_rag(tool_resources):
+            rag_message: str = "Here are relevant docs needed to reply:\n"
+
+            query_message: ChatMessage = chat_thread_messages[-1]
+
             query_service = QueryService(db=session)
             file_search: BetaThreadToolResourcesFileSearch = cast(
                 BetaThreadToolResourcesFileSearch, tool_resources.file_search
             )
-            vector_store_ids: list[str] = cast(list[str], file_search.vector_store_ids)

+            vector_store_ids: list[str] = cast(list[str], file_search.vector_store_ids)
+            file_ids: set[str] = set()
             for vector_store_id in vector_store_ids:
                 rag_results_raw: SingleAPIResponse[
                     SearchResponse
                 ] = await query_service.query_rag(
-                    query=first_message.content,
+                    query=query_message.content,
                     vector_store_id=vector_store_id,
                 )
                 rag_responses: SearchResponse = SearchResponse(
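A minimal sketch of the behavioral change in this hunk: the vector store search now uses the most recent message in the thread (the user's current question) as the RAG query instead of the first message. The ChatMessage dataclass and helper below are simplified stand-ins for the API's own types, assuming messages are ordered by created_at as sort_by_created_at does above.

from dataclasses import dataclass


@dataclass
class ChatMessage:  # simplified stand-in for the API's ChatMessage
    role: str
    content: str
    created_at: int


def pick_rag_query(messages: list[ChatMessage]) -> str:
    # Sort oldest-to-newest, then take the last message as the search query.
    ordered = sorted(messages, key=lambda m: m.created_at)
    return ordered[-1].content


thread = [
    ChatMessage("user", "Tell me about cats.", created_at=1),
    ChatMessage("user", "What fruit is in the fridge, and where?", created_at=2),
]
# Before this PR the first message ("Tell me about cats.") drove the search;
# now the latest question does.
assert pick_rag_query(thread) == "What fruit is in the fridge, and where?"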
6 changes: 3 additions & 3 deletions tests/conformance/test_conformance_assistants.py
@@ -1,20 +1,20 @@
 import pytest
 from openai.types.beta.assistant import Assistant

-from .utils import client_config_factory
+from ..utils.client import client_config_factory


 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_assistant(client_name):
     config = client_config_factory(client_name)
-    client = config["client"]
+    client = config.client

     vector_store = client.beta.vector_stores.create(name="Test data")

     assistant = client.beta.assistants.create(
         name="Test Assistant",
         instructions="You must provide a response based on the attached files.",
-        model=config["model"],
+        model=config.model,
         tools=[{"type": "file_search"}],
         tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
     )

4 changes: 2 additions & 2 deletions tests/conformance/test_conformance_threads.py
@@ -2,7 +2,7 @@
 from openai.types.beta.thread import Thread
 from openai.types.beta.threads import Message, TextContentBlock, Text

-from .utils import client_config_factory
+from ..utils.client import client_config_factory


 def make_mock_message_object(role, message_text):
@@ -37,7 +37,7 @@ def make_mock_message_simple(role, message_text):
 )
 def test_thread(client_name, test_messages):
     config = client_config_factory(client_name)
-    client = config["client"]
+    client = config.client

     thread = client.beta.threads.create(messages=test_messages)

6 changes: 3 additions & 3 deletions tests/conformance/test_conformance_tools.py
@@ -7,7 +7,7 @@
 from openai.types.beta.threads.message import Message
 import re

-from .utils import client_config_factory, text_file_path
+from ..utils.client import client_config_factory, text_file_path


 def make_vector_store_with_file(client):
@@ -46,10 +46,10 @@ def validate_annotation_format(annotation):
 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_thread_file_annotations(client_name):
     config = client_config_factory(client_name)
-    client = config["client"] # shorthand
+    client = config.client # shorthand

     vector_store = make_vector_store_with_file(client)
-    assistant = make_test_assistant(client, config["model"], vector_store.id)
+    assistant = make_test_assistant(client, config.model, vector_store.id)
     thread = client.beta.threads.create()

     client.beta.threads.messages.create(
8 changes: 4 additions & 4 deletions tests/conformance/test_conformance_vectorstore.py
@@ -3,13 +3,13 @@
 from openai.types.beta.vector_store import VectorStore
 from openai.types.beta.vector_store_deleted import VectorStoreDeleted

-from .utils import client_config_factory
+from ..utils.client import client_config_factory


 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_vector_store_create(client_name):
     config = client_config_factory(client_name)
-    client = config["client"] # shorthand
+    client = config.client # shorthand

     vector_store = client.beta.vector_stores.create(name="Test data")

@@ -19,7 +19,7 @@ def test_vector_store_create(client_name):
 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_vector_store_list(client_name):
     config = client_config_factory(client_name)
-    client = config["client"] # shorthand
+    client = config.client # shorthand

     client.beta.vector_stores.create(name="Test data")

@@ -34,7 +34,7 @@ def test_vector_store_list(client_name):
 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_vector_store_delete(client_name):
     config = client_config_factory(client_name)
-    client = config["client"]
+    client = config.client

     vector_store = client.beta.vector_stores.create(name="Test data")

8 changes: 4 additions & 4 deletions tests/conformance/test_files.py
@@ -6,13 +6,13 @@
 )
 from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile

-from .utils import client_config_factory, text_file_path
+from ..utils.client import client_config_factory, text_file_path


 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_file_upload(client_name):
     config = client_config_factory(client_name)
-    client = config["client"] # shorthand
+    client = config.client # shorthand

     vector_store = client.beta.vector_stores.create(name="Test data")
     with open(text_file_path(), "rb") as file:
@@ -27,7 +27,7 @@ def test_file_upload(client_name):
 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_file_upload_batches(client_name):
     config = client_config_factory(client_name)
-    client = config["client"] # shorthand
+    client = config.client # shorthand

     vector_store = client.beta.vector_stores.create(name="Test data")

@@ -43,7 +43,7 @@ def test_file_upload_batches(client_name):
 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_file_delete(client_name):
     config = client_config_factory(client_name)
-    client = config["client"]
+    client = config.client

     vector_store = client.beta.vector_stores.create(name="Test data")
     with open(text_file_path(), "rb") as file:
6 changes: 3 additions & 3 deletions tests/conformance/test_messages.py
@@ -2,13 +2,13 @@

 from openai.types.beta.threads.message import Message

-from .utils import client_config_factory
+from ..utils.client import client_config_factory


 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_message_create(client_name):
     config = client_config_factory(client_name)
-    client = config["client"]
+    client = config.client

     thread = client.beta.threads.create()
     message = client.beta.threads.messages.create(
@@ -23,7 +23,7 @@ def test_message_create(client_name):
 @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
 def test_message_list(client_name):
     config = client_config_factory(client_name)
-    client = config["client"]
+    client = config.client

     thread = client.beta.threads.create()
     client.beta.threads.messages.create(
1 change: 1 addition & 0 deletions tests/data/test_rag_1.1.txt
@@ -0,0 +1 @@
Cats are fascinating creatures known for their agility and independence. Domestic cats, descendants of African wildcats, have been companions to humans for over 4,000 years. They possess excellent night vision, allowing them to see in light levels six times lower than what humans need.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.2.txt
@@ -0,0 +1 @@
With approximately 32 muscles in each ear, cats have an exceptional sense of hearing and can rotate their ears 180 degrees. Their unique grooming behavior helps to maintain their coat and regulate body temperature. Additionally, cats have a specialized collarbone that enables them to always land on their feet, a skill they demonstrate with remarkable precision.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.3.txt
@@ -0,0 +1 @@
Cats are remarkable animals with a range of intriguing characteristics. They have a highly developed sense of smell, with about 50 to 80 million scent receptors in their noses, compared to humans who have only around 5 million.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.4.txt
@@ -0,0 +1 @@
Cats can make over 100 different sounds, including purring, meowing, and hissing, which they use to communicate with their humans and other animals. Their whiskers are highly sensitive and can detect changes in their environment, helping them navigate and judge spaces.
1 change: 1 addition & 0 deletions tests/data/test_rag_1.5.txt
@@ -0,0 +1 @@
Cats also have a unique grooming behavior, with their tongues covered in tiny, hook-shaped structures called papillae that help them clean their fur and remove loose hair. Additionally, cats sleep for about 12 to 16 hours a day, making them one of the sleepiest of all domestic animals.
1 change: 1 addition & 0 deletions tests/data/test_rag_2.1.txt
@@ -0,0 +1 @@
There is only one piece of fruit in the fridge. It is an orange, and located on the bottom shelf, behind the pot of stew.
77 changes: 77 additions & 0 deletions tests/integration/api/test_rag_files.py
@@ -0,0 +1,77 @@
import os
from pathlib import Path
from openai.types.beta.threads.text import Text

from ...utils.client import client_config_factory


def make_test_assistant(client, model, vector_store_id):
    assistant = client.beta.assistants.create(
        name="Test Assistant",
        instructions="You must provide a response based on the attached files.",
        model=model,
        tools=[{"type": "file_search"}],
        tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}},
    )
    return assistant


def make_test_run(client, assistant, thread):
    run = client.beta.threads.runs.create_and_poll(
        assistant_id=assistant.id, thread_id=thread.id
    )
    return run


def test_rag_needle_haystack():
    config = client_config_factory("leapfrogai")
    client = config.client

    vector_store = client.beta.vector_stores.create(name="Test data")
    file_path = "../../data"
    file_names = [
        "test_rag_1.1.txt",
        "test_rag_1.2.txt",
        "test_rag_1.3.txt",
        "test_rag_1.4.txt",
        "test_rag_1.5.txt",
        "test_rag_2.1.txt",
    ]
    vector_store_files = []
    for file_name in file_names:
        with open(
            f"{Path(os.path.dirname(__file__))}/{file_path}/{file_name}", "rb"
        ) as file:
            vector_store_files.append(
                client.beta.vector_stores.files.upload(
                    vector_store_id=vector_store.id, file=file
                )
            )

    assistant = make_test_assistant(client, config.model, vector_store.id)
    thread = client.beta.threads.create()

    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content="Tell me about cats.",
    )
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content="There is one piece of fruit in the fridge. What is it and where is it located?",
    )
    run = make_test_run(client, assistant, thread)

    messages = client.beta.threads.messages.list(
        thread_id=thread.id, run_id=run.id
    ).data

    # Get the response content from the last message
    message_content = messages[-1].content[0].text
    assert isinstance(message_content, Text)
    assert "orange" in message_content.value
    assert len(message_content.annotations) > 0

    for a in message_content.annotations:
        print(a.text)
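Assuming LEAPFROGAI_API_URL and SUPABASE_USER_JWT are exported and the repository root is the working directory, the new needle-in-a-haystack test could be run on its own through pytest's Python entry point (equivalent to invoking pytest on the file from the command line):

import pytest

# Run only the new RAG integration test; -v prints the test name and outcome.
pytest.main(["tests/integration/api/test_rag_files.py::test_rag_needle_haystack", "-v"])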
19 changes: 15 additions & 4 deletions tests/conformance/utils.py → tests/utils/client.py
@@ -18,12 +18,23 @@ def openai_client():
 def leapfrogai_client():
     return OpenAI(
         base_url=os.getenv("LEAPFROGAI_API_URL"),
-        api_key=os.getenv("LEAPFROGAI_API_KEY"),
+        api_key=os.getenv("SUPABASE_USER_JWT"),
     )


-def client_config_factory(client_name):
+class ClientConfig:
+    client: OpenAI
+    model: str
+
+    def __init__(self, client: OpenAI, model: str):
+        self.client = client
+        self.model = model
+
+
+def client_config_factory(client_name) -> ClientConfig:
     if client_name == "openai":
-        return dict(client=openai_client(), model=OPENAI_MODEL)
+        return ClientConfig(client=openai_client(), model=OPENAI_MODEL)
     elif client_name == "leapfrogai":
-        return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
+        return ClientConfig(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
+    else:
+        raise ValueError(f"Unknown client name: {client_name}")
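A short usage sketch of the refactored helper: client_config_factory now returns a ClientConfig object, so callers use attribute access rather than dictionary keys. This assumes LEAPFROGAI_API_URL and SUPABASE_USER_JWT are set (as leapfrogai_client above requires) and that the tests package is importable from the repository root.

from tests.utils.client import client_config_factory  # new module location from this PR

config = client_config_factory("leapfrogai")  # or "openai"
client = config.client  # previously config["client"]
vector_store = client.beta.vector_stores.create(name="Test data")
print(vector_store.id, config.model)  # config.model was previously config["model"]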