Skip to content

Commit

Permalink
chore: typed return for client factory
Browse files Browse the repository at this point in the history
  • Loading branch information
alekst23 committed Sep 4, 2024
1 parent 696a6ed commit 70b6dda
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 51 deletions.
6 changes: 3 additions & 3 deletions tests/conformance/test_conformance_assistants.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import pytest
from openai.types.beta.assistant import Assistant

from .utils import client_config_factory
from ..utils.client import client_config_factory


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_assistant(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")

assistant = client.beta.assistants.create(
name="Test Assistant",
instructions="You must provide a response based on the attached files.",
model=config["model"],
model=config.model,
tools=[{"type": "file_search"}],
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
Expand Down
4 changes: 2 additions & 2 deletions tests/conformance/test_conformance_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from openai.types.beta.thread import Thread
from openai.types.beta.threads import Message, TextContentBlock, Text

from .utils import client_config_factory
from ..utils.client import client_config_factory


def make_mock_message_object(role, message_text):
Expand Down Expand Up @@ -37,7 +37,7 @@ def make_mock_message_simple(role, message_text):
)
def test_thread(client_name, test_messages):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

thread = client.beta.threads.create(messages=test_messages)

Expand Down
6 changes: 3 additions & 3 deletions tests/conformance/test_conformance_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from openai.types.beta.threads.message import Message
import re

from .utils import client_config_factory, text_file_path
from ..utils.client import client_config_factory, text_file_path


def make_vector_store_with_file(client):
Expand Down Expand Up @@ -46,10 +46,10 @@ def validate_annotation_format(annotation):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_thread_file_annotations(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = make_vector_store_with_file(client)
assistant = make_test_assistant(client, config["model"], vector_store.id)
assistant = make_test_assistant(client, config.model, vector_store.id)
thread = client.beta.threads.create()

client.beta.threads.messages.create(
Expand Down
8 changes: 4 additions & 4 deletions tests/conformance/test_conformance_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from openai.types.beta.vector_store import VectorStore
from openai.types.beta.vector_store_deleted import VectorStoreDeleted

from .utils import client_config_factory
from ..utils.client import client_config_factory


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_create(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = client.beta.vector_stores.create(name="Test data")

Expand All @@ -19,7 +19,7 @@ def test_vector_store_create(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_list(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

client.beta.vector_stores.create(name="Test data")

Expand All @@ -34,7 +34,7 @@ def test_vector_store_list(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_vector_store_delete(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")

Expand Down
8 changes: 4 additions & 4 deletions tests/conformance/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
)
from openai.types.beta.vector_stores.vector_store_file import VectorStoreFile

from .utils import client_config_factory, text_file_path
from ..utils.client import client_config_factory, text_file_path


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_upload(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = client.beta.vector_stores.create(name="Test data")
with open(text_file_path(), "rb") as file:
Expand All @@ -27,7 +27,7 @@ def test_file_upload(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_upload_batches(client_name):
config = client_config_factory(client_name)
client = config["client"] # shorthand
client = config.client # shorthand

vector_store = client.beta.vector_stores.create(name="Test data")

Expand All @@ -43,7 +43,7 @@ def test_file_upload_batches(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_file_delete(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")
with open(text_file_path(), "rb") as file:
Expand Down
6 changes: 3 additions & 3 deletions tests/conformance/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from openai.types.beta.threads.message import Message

from .utils import client_config_factory
from ..utils.client import client_config_factory


@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_message_create(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

thread = client.beta.threads.create()
message = client.beta.threads.messages.create(
Expand All @@ -23,7 +23,7 @@ def test_message_create(client_name):
@pytest.mark.parametrize("client_name", ["openai", "leapfrogai"])
def test_message_list(client_name):
config = client_config_factory(client_name)
client = config["client"]
client = config.client

thread = client.beta.threads.create()
client.beta.threads.messages.create(
Expand Down
31 changes: 3 additions & 28 deletions tests/integration/api/test_rag_files.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,8 @@
import os
from pathlib import Path
from openai.types.beta.threads.text import Text
from openai import OpenAI


LEAPFROGAI_MODEL = "llama-cpp-python"
OPENAI_MODEL = "gpt-4o-mini"


def text_file_path():
return Path(os.path.dirname(__file__) + "/../../data/test_with_data.txt")


def openai_client():
return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def leapfrogai_client():
return OpenAI(
base_url=os.getenv("LEAPFROGAI_API_URL"),
api_key=os.getenv("LEAPFROGAI_API_KEY"),
)


def client_config_factory(client_name):
if client_name == "openai":
return dict(client=openai_client(), model=OPENAI_MODEL)
elif client_name == "leapfrogai":
return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
from ...utils.client import client_config_factory


def make_test_assistant(client, model, vector_store_id):
Expand All @@ -50,7 +25,7 @@ def make_test_run(client, assistant, thread):

def test_rag_needle_haystack():
config = client_config_factory("leapfrogai")
client = config["client"]
client = config.client

vector_store = client.beta.vector_stores.create(name="Test data")
file_path = "../../data"
Expand All @@ -73,7 +48,7 @@ def test_rag_needle_haystack():
)
)

assistant = make_test_assistant(client, config["model"], vector_store.id)
assistant = make_test_assistant(client, config.model, vector_store.id)
thread = client.beta.threads.create()

client.beta.threads.messages.create(
Expand Down
19 changes: 15 additions & 4 deletions tests/conformance/utils.py → tests/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,23 @@ def openai_client():
def leapfrogai_client():
return OpenAI(
base_url=os.getenv("LEAPFROGAI_API_URL"),
api_key=os.getenv("LEAPFROGAI_API_KEY"),
api_key=os.getenv("SUPABASE_USER_JWT"),
)


def client_config_factory(client_name):
class ClientConfig:
client: OpenAI
model: str

def __init__(self, client: OpenAI, model: str):
self.client = client
self.model = model


def client_config_factory(client_name) -> ClientConfig:
if client_name == "openai":
return dict(client=openai_client(), model=OPENAI_MODEL)
return ClientConfig(client=openai_client(), model=OPENAI_MODEL)
elif client_name == "leapfrogai":
return dict(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
return ClientConfig(client=leapfrogai_client(), model=LEAPFROGAI_MODEL)
else:
raise ValueError(f"Unknown client name: {client_name}")

0 comments on commit 70b6dda

Please sign in to comment.