diff --git a/src/leapfrogai_api/backend/rag/index.py b/src/leapfrogai_api/backend/rag/index.py index babbca1c9..6014a9d9a 100644 --- a/src/leapfrogai_api/backend/rag/index.py +++ b/src/leapfrogai_api/backend/rag/index.py @@ -55,12 +55,10 @@ async def index_file(self, vector_store_id: str, file_id: str) -> VectorStoreFil crud_vector_store_file = CRUDVectorStoreFile(db=self.db) crud_vector_store = CRUDVectorStore(db=self.db) - if await crud_vector_store_file.get( + if file_existing := await crud_vector_store_file.get( filters=FilterVectorStoreFile(vector_store_id=vector_store_id, id=file_id) ): - logger.error("File already indexed: %s", file_id) - raise FileAlreadyIndexedError("File already indexed") - + return file_existing if not ( await crud_vector_store.get(filters=FilterVectorStore(id=vector_store_id)) ): diff --git a/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py b/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py index 0770178f6..cb552bc29 100644 --- a/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py +++ b/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py @@ -231,8 +231,6 @@ def sort_by_created_at(msg: Message): ) ) - first_message: ChatMessage = chat_thread_messages[0] - # Holds the converted thread's messages, this will be built up with a series of push operations chat_messages: list[ChatMessage] = [] @@ -250,24 +248,24 @@ def sort_by_created_at(msg: Message): for message in chat_thread_messages: chat_messages.append(message) - use_rag: bool = self.can_use_rag(tool_resources) + # 4 - The RAG results are appended behind the user's query + if self.can_use_rag(tool_resources): + rag_message: str = "Here are relevant docs needed to reply:\n" - rag_message: str = "Here are relevant docs needed to reply:\n" + query_message: ChatMessage = chat_thread_messages[-1] - # 4 - The RAG results are appended behind the user's query - file_ids: set[str] = set() - if use_rag: query_service = QueryService(db=session) file_search: BetaThreadToolResourcesFileSearch = cast( BetaThreadToolResourcesFileSearch, tool_resources.file_search ) - vector_store_ids: list[str] = cast(list[str], file_search.vector_store_ids) + vector_store_ids: list[str] = cast(list[str], file_search.vector_store_ids) + file_ids: set[str] = set() for vector_store_id in vector_store_ids: rag_results_raw: SingleAPIResponse[ SearchResponse ] = await query_service.query_rag( - query=first_message.content, + query=query_message.content, vector_store_id=vector_store_id, ) rag_responses: SearchResponse = SearchResponse( diff --git a/tests/conformance/test_conformance_assistants.py b/tests/conformance/test_conformance_assistants.py index 8deefb2cf..1ebcd95b6 100644 --- a/tests/conformance/test_conformance_assistants.py +++ b/tests/conformance/test_conformance_assistants.py @@ -1,20 +1,20 @@ import pytest from openai.types.beta.assistant import Assistant -from .utils import client_config_factory +from ..utils.client import client_config_factory @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_assistant(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") assistant = client.beta.assistants.create( name="Test Assistant", instructions="You must provide a response based on the attached files.", - model=config["model"], + model=config.model, tools=[{"type": "file_search"}], tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, ) diff --git a/tests/conformance/test_conformance_threads.py b/tests/conformance/test_conformance_threads.py index 94e619df8..91d17c940 100644 --- a/tests/conformance/test_conformance_threads.py +++ b/tests/conformance/test_conformance_threads.py @@ -2,7 +2,7 @@ from openai.types.beta.thread import Thread from openai.types.beta.threads import Message, TextContentBlock, Text -from .utils import client_config_factory +from ..utils.client import client_config_factory def make_mock_message_object(role, message_text): @@ -37,7 +37,7 @@ def make_mock_message_simple(role, message_text): ) def test_thread(client_name, test_messages): config = client_config_factory(client_name) - client = config["client"] + client = config.client thread = client.beta.threads.create(messages=test_messages) diff --git a/tests/conformance/test_conformance_tools.py b/tests/conformance/test_conformance_tools.py index 241283e0a..9b69193d5 100644 --- a/tests/conformance/test_conformance_tools.py +++ b/tests/conformance/test_conformance_tools.py @@ -7,7 +7,7 @@ from openai.types.beta.threads.message import Message import re -from .utils import client_config_factory, text_file_path +from ..utils.client import client_config_factory, text_file_path def make_vector_store_with_file(client): @@ -46,10 +46,10 @@ def validate_annotation_format(annotation): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_thread_file_annotations(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = make_vector_store_with_file(client) - assistant = make_test_assistant(client, config["model"], vector_store.id) + assistant = make_test_assistant(client, config.model, vector_store.id) thread = client.beta.threads.create() client.beta.threads.messages.create( diff --git a/tests/conformance/test_conformance_vectorstore.py b/tests/conformance/test_conformance_vectorstore.py index b95dddf89..25ad52f9d 100644 --- a/tests/conformance/test_conformance_vectorstore.py +++ b/tests/conformance/test_conformance_vectorstore.py @@ -3,13 +3,13 @@ from openai.types.beta.vector_store import VectorStore from openai.types.beta.vector_store_deleted import VectorStoreDeleted -from .utils import client_config_factory +from ..utils.client import client_config_factory @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_vector_store_create(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = client.beta.vector_stores.create(name="Test data") @@ -19,7 +19,7 @@ def test_vector_store_create(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_vector_store_list(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand client.beta.vector_stores.create(name="Test data") @@ -34,7 +34,7 @@ def test_vector_store_list(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_vector_store_delete(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") diff --git a/tests/conformance/test_files.py b/tests/conformance/test_files.py index 4254e4ee5..a1510b790 100644 --- a/tests/conformance/test_files.py +++ b/tests/conformance/test_files.py @@ -6,13 +6,13 @@ ) from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile -from .utils import client_config_factory, text_file_path +from ..utils.client import client_config_factory, text_file_path @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_file_upload(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = client.beta.vector_stores.create(name="Test data") with open(text_file_path(), "rb") as file: @@ -27,7 +27,7 @@ def test_file_upload(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_file_upload_batches(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = client.beta.vector_stores.create(name="Test data") @@ -43,7 +43,7 @@ def test_file_upload_batches(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_file_delete(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") with open(text_file_path(), "rb") as file: diff --git a/tests/conformance/test_messages.py b/tests/conformance/test_messages.py index d4e1709ed..f58f22b9c 100644 --- a/tests/conformance/test_messages.py +++ b/tests/conformance/test_messages.py @@ -2,13 +2,13 @@ from openai.types.beta.threads.message import Message -from .utils import client_config_factory +from ..utils.client import client_config_factory @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_message_create(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client thread = client.beta.threads.create() message = client.beta.threads.messages.create( @@ -23,7 +23,7 @@ def test_message_create(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_message_list(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client thread = client.beta.threads.create() client.beta.threads.messages.create( diff --git a/tests/data/test_rag_1.1.txt b/tests/data/test_rag_1.1.txt new file mode 100644 index 000000000..d72f58b08 --- /dev/null +++ b/tests/data/test_rag_1.1.txt @@ -0,0 +1 @@ +Cats are fascinating creatures known for their agility and independence. Domestic cats, descendants of African wildcats, have been companions to humans for over 4,000 years. They possess excellent night vision, allowing them to see in light levels six times lower than what humans need. diff --git a/tests/data/test_rag_1.2.txt b/tests/data/test_rag_1.2.txt new file mode 100644 index 000000000..24ce16ba2 --- /dev/null +++ b/tests/data/test_rag_1.2.txt @@ -0,0 +1 @@ +With approximately 32 muscles in each ear, cats have an exceptional sense of hearing and can rotate their ears 180 degrees. Their unique grooming behavior helps to maintain their coat and regulate body temperature. Additionally, cats have a specialized collarbone that enables them to always land on their feet, a skill they demonstrate with remarkable precision. diff --git a/tests/data/test_rag_1.3.txt b/tests/data/test_rag_1.3.txt new file mode 100644 index 000000000..49c9bfd92 --- /dev/null +++ b/tests/data/test_rag_1.3.txt @@ -0,0 +1 @@ +Cats are remarkable animals with a range of intriguing characteristics. They have a highly developed sense of smell, with about 50 to 80 million scent receptors in their noses, compared to humans who have only around 5 million. diff --git a/tests/data/test_rag_1.4.txt b/tests/data/test_rag_1.4.txt new file mode 100644 index 000000000..1cfcb0c2d --- /dev/null +++ b/tests/data/test_rag_1.4.txt @@ -0,0 +1 @@ +Cats can make over 100 different sounds, including purring, meowing, and hissing, which they use to communicate with their humans and other animals. Their whiskers are highly sensitive and can detect changes in their environment, helping them navigate and judge spaces. diff --git a/tests/data/test_rag_1.5.txt b/tests/data/test_rag_1.5.txt new file mode 100644 index 000000000..88cdd26dc --- /dev/null +++ b/tests/data/test_rag_1.5.txt @@ -0,0 +1 @@ +Cats also have a unique grooming behavior, with their tongues covered in tiny, hook-shaped structures called papillae that help them clean their fur and remove loose hair. Additionally, cats sleep for about 12 to 16 hours a day, making them one of the sleepiest of all domestic animals. diff --git a/tests/data/test_rag_2.1.txt b/tests/data/test_rag_2.1.txt new file mode 100644 index 000000000..259e5f1f8 --- /dev/null +++ b/tests/data/test_rag_2.1.txt @@ -0,0 +1 @@ +There is only one piece of fruit in the fridge. It is an orange, and located on the bottom shelf, behind the pot of stew. diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py new file mode 100644 index 000000000..a5a743f6e --- /dev/null +++ b/tests/integration/api/test_rag_files.py @@ -0,0 +1,77 @@ +import os +from pathlib import Path +from openai.types.beta.threads.text import Text + +from ...utils.client import client_config_factory + + +def make_test_assistant(client, model, vector_store_id): + assistant = client.beta.assistants.create( + name="Test Assistant", + instructions="You must provide a response based on the attached files.", + model=model, + tools=[{"type": "file_search"}], + tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}}, + ) + return assistant + + +def make_test_run(client, assistant, thread): + run = client.beta.threads.runs.create_and_poll( + assistant_id=assistant.id, thread_id=thread.id + ) + return run + + +def test_rag_needle_haystack(): + config = client_config_factory("leapfrogai") + client = config.client + + vector_store = client.beta.vector_stores.create(name="Test data") + file_path = "../../data" + file_names = [ + "test_rag_1.1.txt", + "test_rag_1.2.txt", + "test_rag_1.3.txt", + "test_rag_1.4.txt", + "test_rag_1.5.txt", + "test_rag_2.1.txt", + ] + vector_store_files = [] + for file_name in file_names: + with open( + f"{Path(os.path.dirname(__file__))}/{file_path}/{file_name}", "rb" + ) as file: + vector_store_files.append( + client.beta.vector_stores.files.upload( + vector_store_id=vector_store.id, file=file + ) + ) + + assistant = make_test_assistant(client, config.model, vector_store.id) + thread = client.beta.threads.create() + + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="Tell me about cats.", + ) + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="There is one piece of fruit in the fridge. What is it and where is it located?", + ) + run = make_test_run(client, assistant, thread) + + messages = client.beta.threads.messages.list( + thread_id=thread.id, run_id=run.id + ).data + + # Get the response content from the last message + message_content = messages[-1].content[0].text + assert isinstance(message_content, Text) + assert "orange" in message_content.value + assert len(message_content.annotations) > 0 + + for a in message_content.annotations: + print(a.text) diff --git a/tests/conformance/utils.py b/tests/utils/client.py similarity index 50% rename from tests/conformance/utils.py rename to tests/utils/client.py index dbfc719d0..7a58b02f5 100644 --- a/tests/conformance/utils.py +++ b/tests/utils/client.py @@ -18,12 +18,23 @@ def openai_client(): def leapfrogai_client(): return OpenAI( base_url=os.getenv("LEAPFROGAI_API_URL"), - api_key=os.getenv("LEAPFROGAI_API_KEY"), + api_key=os.getenv("SUPABASE_USER_JWT"), ) -def client_config_factory(client_name): +class ClientConfig: + client: OpenAI + model: str + + def __init__(self, client: OpenAI, model: str): + self.client = client + self.model = model + + +def client_config_factory(client_name) -> ClientConfig: if client_name == "openai": - return dict(client=openai_client(), model=OPENAI_MODEL) + return ClientConfig(client=openai_client(), model=OPENAI_MODEL) elif client_name == "leapfrogai": - return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL) + return ClientConfig(client=leapfrogai_client(), model=LEAPFROGAI_MODEL) + else: + raise ValueError(f"Unknown client name: {client_name}")