From f49cb755a2bc95f19a4635268e9ec38119138284 Mon Sep 17 00:00:00 2001 From: alekst23 Date: Wed, 4 Sep 2024 11:06:51 -0400 Subject: [PATCH] chore: typed return for client factory --- .../test_conformance_assistants.py | 6 ++-- tests/conformance/test_conformance_threads.py | 4 +-- tests/conformance/test_conformance_tools.py | 6 ++-- .../test_conformance_vectorstore.py | 8 ++--- tests/conformance/test_files.py | 8 ++--- tests/conformance/test_messages.py | 6 ++-- tests/integration/api/test_rag_files.py | 31 ++----------------- .../{conformance/utils.py => utils/client.py} | 19 +++++++++--- 8 files changed, 37 insertions(+), 51 deletions(-) rename tests/{conformance/utils.py => utils/client.py} (50%) diff --git a/tests/conformance/test_conformance_assistants.py b/tests/conformance/test_conformance_assistants.py index 8deefb2cf5..1ebcd95b6d 100644 --- a/tests/conformance/test_conformance_assistants.py +++ b/tests/conformance/test_conformance_assistants.py @@ -1,20 +1,20 @@ import pytest from openai.types.beta.assistant import Assistant -from .utils import client_config_factory +from ..utils.client import client_config_factory @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_assistant(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") assistant = client.beta.assistants.create( name="Test Assistant", instructions="You must provide a response based on the attached files.", - model=config["model"], + model=config.model, tools=[{"type": "file_search"}], tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, ) diff --git a/tests/conformance/test_conformance_threads.py b/tests/conformance/test_conformance_threads.py index 94e619df88..91d17c9406 100644 --- a/tests/conformance/test_conformance_threads.py +++ b/tests/conformance/test_conformance_threads.py @@ -2,7 +2,7 @@ from openai.types.beta.thread import Thread from openai.types.beta.threads import Message, TextContentBlock, Text -from .utils import client_config_factory +from ..utils.client import client_config_factory def make_mock_message_object(role, message_text): @@ -37,7 +37,7 @@ def make_mock_message_simple(role, message_text): ) def test_thread(client_name, test_messages): config = client_config_factory(client_name) - client = config["client"] + client = config.client thread = client.beta.threads.create(messages=test_messages) diff --git a/tests/conformance/test_conformance_tools.py b/tests/conformance/test_conformance_tools.py index 241283e0a1..9b69193d58 100644 --- a/tests/conformance/test_conformance_tools.py +++ b/tests/conformance/test_conformance_tools.py @@ -7,7 +7,7 @@ from openai.types.beta.threads.message import Message import re -from .utils import client_config_factory, text_file_path +from ..utils.client import client_config_factory, text_file_path def make_vector_store_with_file(client): @@ -46,10 +46,10 @@ def validate_annotation_format(annotation): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_thread_file_annotations(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = make_vector_store_with_file(client) - assistant = make_test_assistant(client, config["model"], vector_store.id) + assistant = make_test_assistant(client, config.model, vector_store.id) thread = client.beta.threads.create() client.beta.threads.messages.create( diff --git a/tests/conformance/test_conformance_vectorstore.py b/tests/conformance/test_conformance_vectorstore.py index b95dddf89d..25ad52f9d7 100644 --- a/tests/conformance/test_conformance_vectorstore.py +++ b/tests/conformance/test_conformance_vectorstore.py @@ -3,13 +3,13 @@ from openai.types.beta.vector_store import VectorStore from openai.types.beta.vector_store_deleted import VectorStoreDeleted -from .utils import client_config_factory +from ..utils.client import client_config_factory @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_vector_store_create(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = client.beta.vector_stores.create(name="Test data") @@ -19,7 +19,7 @@ def test_vector_store_create(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_vector_store_list(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand client.beta.vector_stores.create(name="Test data") @@ -34,7 +34,7 @@ def test_vector_store_list(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_vector_store_delete(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") diff --git a/tests/conformance/test_files.py b/tests/conformance/test_files.py index 4254e4ee5a..a1510b7900 100644 --- a/tests/conformance/test_files.py +++ b/tests/conformance/test_files.py @@ -6,13 +6,13 @@ ) from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile -from .utils import client_config_factory, text_file_path +from ..utils.client import client_config_factory, text_file_path @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_file_upload(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = client.beta.vector_stores.create(name="Test data") with open(text_file_path(), "rb") as file: @@ -27,7 +27,7 @@ def test_file_upload(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_file_upload_batches(client_name): config = client_config_factory(client_name) - client = config["client"] # shorthand + client = config.client # shorthand vector_store = client.beta.vector_stores.create(name="Test data") @@ -43,7 +43,7 @@ def test_file_upload_batches(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_file_delete(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") with open(text_file_path(), "rb") as file: diff --git a/tests/conformance/test_messages.py b/tests/conformance/test_messages.py index d4e1709ed5..f58f22b9c2 100644 --- a/tests/conformance/test_messages.py +++ b/tests/conformance/test_messages.py @@ -2,13 +2,13 @@ from openai.types.beta.threads.message import Message -from .utils import client_config_factory +from ..utils.client import client_config_factory @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_message_create(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client thread = client.beta.threads.create() message = client.beta.threads.messages.create( @@ -23,7 +23,7 @@ def test_message_create(client_name): @pytest.mark.parametrize("client_name", ["openai", "leapfrogai"]) def test_message_list(client_name): config = client_config_factory(client_name) - client = config["client"] + client = config.client thread = client.beta.threads.create() client.beta.threads.messages.create( diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py index 7be13f837f..e43ef30a47 100644 --- a/tests/integration/api/test_rag_files.py +++ b/tests/integration/api/test_rag_files.py @@ -1,33 +1,8 @@ import os from pathlib import Path from openai.types.beta.threads.text import Text -from openai import OpenAI - -LEAPFROGAI_MODEL = "llama-cpp-python" -OPENAI_MODEL = "gpt-4o-mini" - - -def text_file_path(): - return Path(os.path.dirname(__file__) + "/../../data/test_with_data.txt") - - -def openai_client(): - return OpenAI(api_key=os.getenv("OPENAI_API_KEY")) - - -def leapfrogai_client(): - return OpenAI( - base_url=os.getenv("LEAPFROGAI_API_URL"), - api_key=os.getenv("LEAPFROGAI_API_KEY"), - ) - - -def client_config_factory(client_name): - if client_name == "openai": - return dict(client=openai_client(), model=OPENAI_MODEL) - elif client_name == "leapfrogai": - return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL) +from ...utils.client import client_config_factory def make_test_assistant(client, model, vector_store_id): @@ -50,7 +25,7 @@ def make_test_run(client, assistant, thread): def test_rag_needle_haystack(): config = client_config_factory("leapfrogai") - client = config["client"] + client = config.client vector_store = client.beta.vector_stores.create(name="Test data") file_path = "../../data" @@ -73,7 +48,7 @@ def test_rag_needle_haystack(): ) ) - assistant = make_test_assistant(client, config["model"], vector_store.id) + assistant = make_test_assistant(client, config.model, vector_store.id) thread = client.beta.threads.create() client.beta.threads.messages.create( diff --git a/tests/conformance/utils.py b/tests/utils/client.py similarity index 50% rename from tests/conformance/utils.py rename to tests/utils/client.py index dbfc719d0c..7a58b02f5e 100644 --- a/tests/conformance/utils.py +++ b/tests/utils/client.py @@ -18,12 +18,23 @@ def openai_client(): def leapfrogai_client(): return OpenAI( base_url=os.getenv("LEAPFROGAI_API_URL"), - api_key=os.getenv("LEAPFROGAI_API_KEY"), + api_key=os.getenv("SUPABASE_USER_JWT"), ) -def client_config_factory(client_name): +class ClientConfig: + client: OpenAI + model: str + + def __init__(self, client: OpenAI, model: str): + self.client = client + self.model = model + + +def client_config_factory(client_name) -> ClientConfig: if client_name == "openai": - return dict(client=openai_client(), model=OPENAI_MODEL) + return ClientConfig(client=openai_client(), model=OPENAI_MODEL) elif client_name == "leapfrogai": - return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL) + return ClientConfig(client=leapfrogai_client(), model=LEAPFROGAI_MODEL) + else: + raise ValueError(f"Unknown client name: {client_name}")