diff --git a/src/codehelp/helper.py b/src/codehelp/helper.py
index 5b37ce9..ad9537c 100644
--- a/src/codehelp/helper.py
+++ b/src/codehelp/helper.py
@@ -129,23 +129,18 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code:
1) A list of response objects from the OpenAI completion (to be stored in the database)
2) A dictionary of response text, potentially including keys 'insufficient' and 'main'.
'''
- client = llm.client
- model = llm.model
-
context_str = context.prompt_str() if context is not None else None
# Launch the "sufficient detail" check concurrently with the main prompt to save time
task_main = asyncio.create_task(
get_completion(
- client,
- model=model,
+ llm,
messages=prompts.make_main_prompt(code, error, issue, context_str),
)
)
task_sufficient = asyncio.create_task(
get_completion(
- client,
- model=model,
+ llm,
messages=prompts.make_sufficient_prompt(code, error, issue, context_str),
)
)
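For readers skimming the patch, here is a minimal sketch of the call pattern this hunk arrives at: both completions are launched as tasks and awaited together, with the single `llm` argument replacing the old `client`/`model` pair. The import paths, the wrapper function name, and the final `gather` step are assumptions for illustration, not part of the patch.

```python
import asyncio

from codehelp import prompts                         # assumed module layout
from gened.openai import LLMConfig, get_completion   # assumed import path

async def run_both(llm: LLMConfig, code: str, error: str, issue: str,
                   context_str: str | None) -> tuple[list[dict[str, str]], dict[str, str]]:
    # Launch both completions concurrently; each call now takes the LLMConfig directly.
    task_main = asyncio.create_task(
        get_completion(llm, messages=prompts.make_main_prompt(code, error, issue, context_str))
    )
    task_sufficient = asyncio.create_task(
        get_completion(llm, messages=prompts.make_sufficient_prompt(code, error, issue, context_str))
    )
    # Each task resolves to (response_object, response_text).
    (main_resp, main_txt), (suff_resp, suff_txt) = await asyncio.gather(task_main, task_sufficient)
    return [main_resp, suff_resp], {"main": main_txt, "insufficient": suff_txt}
```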
@@ -160,7 +155,7 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code:
if "```" in response_txt or "should look like" in response_txt or "should look something like" in response_txt:
# That's probably too much code. Let's clean it up...
cleanup_prompt = prompts.make_cleanup_prompt(response_text=response_txt)
- cleanup_response, cleanup_response_txt = await get_completion(client, model, prompt=cleanup_prompt)
+ cleanup_response, cleanup_response_txt = await get_completion(llm, prompt=cleanup_prompt)
responses.append(cleanup_response)
response_txt = cleanup_response_txt
@@ -322,8 +317,7 @@ def get_topics(llm: LLMConfig, query_id: int) -> list[str]:
)
response, response_txt = asyncio.run(get_completion(
- client=llm.client,
- model=llm.model,
+ llm,
messages=messages,
))
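The synchronous call sites (`get_topics` here, `get_response` in tutor.py below, `test_llm` in class_config.py) all reduce to the same shape: build the input, then drive the single coroutine with `asyncio.run`. A minimal sketch, assuming the same `gened.openai` import path as above; `ask_once` is a hypothetical helper, not code from this repository.

```python
import asyncio

from gened.openai import LLMConfig, get_completion  # assumed import path

def ask_once(llm: LLMConfig, question: str) -> str:
    # Blocking wrapper: wrap the question as a single user message and run the coroutine.
    messages = [{"role": "user", "content": question}]
    _response, response_txt = asyncio.run(get_completion(llm, messages=messages))
    return response_txt
```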
diff --git a/src/codehelp/tutor.py b/src/codehelp/tutor.py
index 0029b27..70abb8d 100644
--- a/src/codehelp/tutor.py
+++ b/src/codehelp/tutor.py
@@ -222,8 +222,7 @@ def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tupl
2) The response text.
'''
response, text = asyncio.run(get_completion(
- client=llm.client,
- model=llm.model,
+ llm,
messages=chat,
))
diff --git a/src/gened/class_config.py b/src/gened/class_config.py
index da6f411..7db2daa 100644
--- a/src/gened/class_config.py
+++ b/src/gened/class_config.py
@@ -72,11 +72,7 @@ def config_form() -> str:
@bp.route("/test_llm")
@with_llm()
def test_llm(llm: LLMConfig) -> str:
- response, response_txt = asyncio.run(get_completion(
- client=llm.client,
- model=llm.model,
- prompt="Please write 'OK'"
- ))
+ response, response_txt = asyncio.run(get_completion(llm, prompt="Please write 'OK'"))
if 'error' in response:
return f"Error:
{response_txt}"
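Put together with the `with_llm` decorator (updated in gened/openai.py below), a view only ever sees the consolidated config object. A sketch of a minimal route in the same style as `test_llm`; the blueprint, the route name, and the assumption that `with_llm` is importable from `gened.openai` are illustrative only.

```python
import asyncio

from flask import Blueprint

from gened.openai import LLMConfig, get_completion, with_llm  # with_llm location assumed

bp = Blueprint("llm_check", __name__)

@bp.route("/llm_check")
@with_llm()
def llm_check(llm: LLMConfig) -> str:
    # The decorator injects a ready-to-use LLMConfig; prompt-style calls need only the text.
    response, response_txt = asyncio.run(get_completion(llm, prompt="Please write 'OK'"))
    if 'error' in response:
        return f"Error: {response_txt}"
    return response_txt
```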
diff --git a/src/gened/openai.py b/src/gened/openai.py
index 88b7973..5981d2c 100644
--- a/src/gened/openai.py
+++ b/src/gened/openai.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from functools import wraps
from sqlite3 import Row
-from typing import ParamSpec, TypeVar
+from typing import ParamSpec, TypeAlias, TypeVar
import openai
from flask import current_app, flash, render_template
@@ -29,15 +29,18 @@ class NoTokensError(Exception):
pass
+LLMClient: TypeAlias = AsyncOpenAI
+
+
@dataclass
class LLMConfig:
api_key: str
model: str
tokens_remaining: int | None = None # None if current user is not using tokens
- _client: AsyncOpenAI | None = field(default=None, init=False, repr=False)
+ _client: LLMClient | None = field(default=None, init=False, repr=False)
@property
- def client(self) -> AsyncOpenAI:
+ def client(self) -> LLMClient:
""" Lazy load the LLM client object only when requested. """
if self._client is None:
self._client = AsyncOpenAI(api_key=self.api_key)
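A quick sketch of what the lazy `client` property buys: constructing `LLMConfig` stays cheap, and the `AsyncOpenAI` object is created once on first access and then reused. The key and model strings below are placeholders, and the import path is assumed.

```python
from gened.openai import LLMConfig  # assumed import path

# Constructing the config does not create an AsyncOpenAI instance yet.
llm = LLMConfig(api_key="YOUR_API_KEY", model="some-model-name")  # placeholder values

# First access builds the client and caches it in _client; later accesses reuse it.
client_a = llm.client
client_b = llm.client
assert client_a is client_b
```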
@@ -45,7 +48,7 @@ def client(self) -> AsyncOpenAI:
def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLMConfig:
- ''' Get model details and an initialized OpenAI client based on the
+ ''' Get model details and an initialized LLM API client based on the
arguments and the current user and class.
Procedure, depending on arguments, user, and class:
@@ -61,7 +64,7 @@ def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLMConfig:
Otherwise, their token count is decremented.
Returns:
- LLMConfig with an OpenAI client and model name.
+ LLMConfig with an API client and model name.
Raises various exceptions in cases where a key and model are not available.
'''
@@ -177,7 +180,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:
flash("Error: No API key set. An API key must be set by the instructor before this page can be used.")
return render_template("error.html")
except NoTokensError:
- flash("You have used all of your free queries. If you are using this application in a class, please connect using the link from your class for continued access. Otherwise, you can create a class and add an OpenAI API key or contact us if you want to continue using this application.", "warning")
+ flash("You have used all of your free queries. If you are using this application in a class, please connect using the link from your class for continued access. Otherwise, you can create a class and add an API key or contact us if you want to continue using this application.", "warning")
return render_template("error.html")
kwargs['llm'] = llm
@@ -193,7 +196,7 @@ def get_models() -> list[Row]:
return models
-async def get_completion(client: AsyncOpenAI, model: str, prompt: str | None = None, messages: list[ChatCompletionMessageParam] | None = None) -> tuple[dict[str, str], str]:
+async def get_completion(llm: LLMConfig, prompt: str | None = None, messages: list[ChatCompletionMessageParam] | None = None) -> tuple[dict[str, str], str]:
'''
-    model can be any valid OpenAI model name that can be used via the chat completion API.
+    llm.model can be any valid OpenAI model name that can be used via the chat completion API.
@@ -208,8 +211,8 @@ async def get_completion(client: AsyncOpenAI, model: str, prompt: str | None = N
assert prompt is not None
messages = [{"role": "user", "content": prompt}]
- response = await client.chat.completions.create(
- model=model,
+ response = await llm.client.chat.completions.create(
+ model=llm.model,
messages=messages,
temperature=0.25,
max_tokens=1000,
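For downstream callers, the new `get_completion` entry points look like this: either a bare `prompt` string (wrapped into a single user message internally, as the hunk above shows) or a pre-built `messages` list. A sketch under the same assumed import path; `demo` and `some_llm_config` are illustrative names.

```python
import asyncio

from gened.openai import LLMConfig, get_completion  # assumed import path

async def demo(llm: LLMConfig) -> None:
    # Bare prompt string: get_completion wraps it as a single user message.
    resp1, txt1 = await get_completion(llm, prompt="Explain list comprehensions briefly.")
    # Pre-built chat-completions message list.
    resp2, txt2 = await get_completion(
        llm,
        messages=[{"role": "user", "content": "Explain list comprehensions briefly."}],
    )
    print(txt1)
    print(txt2)

# asyncio.run(demo(some_llm_config))  # some_llm_config: an LLMConfig from _get_llm()/with_llm()
```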
diff --git a/src/starburst/helper.py b/src/starburst/helper.py
index 10c3d65..956b5f6 100644
--- a/src/starburst/helper.py
+++ b/src/starburst/helper.py
@@ -52,8 +52,7 @@ async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tup
'''
task_main = asyncio.create_task(
get_completion(
- client=llm.client,
- model=llm.model,
+ llm,
prompt=prompts.make_main_prompt(assignment, topics),
)
)