diff --git a/src/codehelp/helper.py b/src/codehelp/helper.py
index 5b37ce9..ad9537c 100644
--- a/src/codehelp/helper.py
+++ b/src/codehelp/helper.py
@@ -129,23 +129,18 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code:
       1) A list of response objects from the OpenAI completion (to be stored in the database)
       2) A dictionary of response text, potentially including keys 'insufficient' and 'main'.
     '''
-    client = llm.client
-    model = llm.model
-
     context_str = context.prompt_str() if context is not None else None
 
     # Launch the "sufficient detail" check concurrently with the main prompt to save time
     task_main = asyncio.create_task(
         get_completion(
-            client,
-            model=model,
+            llm,
             messages=prompts.make_main_prompt(code, error, issue, context_str),
         )
     )
     task_sufficient = asyncio.create_task(
         get_completion(
-            client,
-            model=model,
+            llm,
             messages=prompts.make_sufficient_prompt(code, error, issue, context_str),
         )
     )
@@ -160,7 +155,7 @@
     if "```" in response_txt or "should look like" in response_txt or "should look something like" in response_txt:
         # That's probably too much code. Let's clean it up...
         cleanup_prompt = prompts.make_cleanup_prompt(response_text=response_txt)
-        cleanup_response, cleanup_response_txt = await get_completion(client, model, prompt=cleanup_prompt)
+        cleanup_response, cleanup_response_txt = await get_completion(llm, prompt=cleanup_prompt)
         responses.append(cleanup_response)
         response_txt = cleanup_response_txt
 
@@ -322,8 +317,7 @@ def get_topics(llm: LLMConfig, query_id: int) -> list[str]:
     )
 
     response, response_txt = asyncio.run(get_completion(
-        client=llm.client,
-        model=llm.model,
+        llm,
         messages=messages,
     ))
 
diff --git a/src/codehelp/tutor.py b/src/codehelp/tutor.py
index 0029b27..70abb8d 100644
--- a/src/codehelp/tutor.py
+++ b/src/codehelp/tutor.py
@@ -222,8 +222,7 @@ def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tupl
       2) The response text.
     '''
     response, text = asyncio.run(get_completion(
-        client=llm.client,
-        model=llm.model,
+        llm,
         messages=chat,
     ))
 
diff --git a/src/gened/class_config.py b/src/gened/class_config.py
index da6f411..7db2daa 100644
--- a/src/gened/class_config.py
+++ b/src/gened/class_config.py
@@ -72,11 +72,7 @@ def config_form() -> str:
 @bp.route("/test_llm")
 @with_llm()
 def test_llm(llm: LLMConfig) -> str:
-    response, response_txt = asyncio.run(get_completion(
-        client=llm.client,
-        model=llm.model,
-        prompt="Please write 'OK'"
-    ))
+    response, response_txt = asyncio.run(get_completion(llm, prompt="Please write 'OK'"))
 
     if 'error' in response:
         return f"Error:<br>{response_txt}"
diff --git a/src/gened/openai.py b/src/gened/openai.py
index 88b7973..5981d2c 100644
--- a/src/gened/openai.py
+++ b/src/gened/openai.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from sqlite3 import Row
-from typing import ParamSpec, TypeVar
+from typing import ParamSpec, TypeAlias, TypeVar
 
 import openai
 from flask import current_app, flash, render_template
@@ -29,15 +29,18 @@ class NoTokensError(Exception):
     pass
 
 
+LLMClient: TypeAlias = AsyncOpenAI
+
+
 @dataclass
 class LLMConfig:
     api_key: str
     model: str
     tokens_remaining: int | None = None  # None if current user is not using tokens
-    _client: AsyncOpenAI | None = field(default=None, init=False, repr=False)
+    _client: LLMClient | None = field(default=None, init=False, repr=False)
 
     @property
-    def client(self) -> AsyncOpenAI:
+    def client(self) -> LLMClient:
         """ Lazy load the LLM client object only when requested. """
         if self._client is None:
             self._client = AsyncOpenAI(api_key=self.api_key)
@@ -45,7 +48,7 @@ def client(self) -> AsyncOpenAI:
 
 
 def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLMConfig:
-    ''' Get model details and an initialized OpenAI client based on the
+    ''' Get model details and an initialized LLM API client based on the
         arguments and the current user and class.
 
     Procedure, depending on arguments, user, and class:
@@ -61,7 +64,7 @@ def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLMConfig:
       Otherwise, their token count is decremented.
 
     Returns:
-      LLMConfig with an OpenAI client and model name.
+      LLMConfig with an API client and model name.
 
     Raises various exceptions in cases where a key and model are not available.
     '''
@@ -177,7 +180,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:
             flash("Error: No API key set. An API key must be set by the instructor before this page can be used.")
             return render_template("error.html")
         except NoTokensError:
-            flash("You have used all of your free queries. If you are using this application in a class, please connect using the link from your class for continued access. Otherwise, you can create a class and add an OpenAI API key or contact us if you want to continue using this application.", "warning")
+            flash("You have used all of your free queries. If you are using this application in a class, please connect using the link from your class for continued access. Otherwise, you can create a class and add an API key or contact us if you want to continue using this application.", "warning")
             return render_template("error.html")
 
         kwargs['llm'] = llm
@@ -193,7 +196,7 @@ def get_models() -> list[Row]:
     return models
 
 
-async def get_completion(client: AsyncOpenAI, model: str, prompt: str | None = None, messages: list[ChatCompletionMessageParam] | None = None) -> tuple[dict[str, str], str]:
+async def get_completion(llm: LLMConfig, prompt: str | None = None, messages: list[ChatCompletionMessageParam] | None = None) -> tuple[dict[str, str], str]:
     ''' model can be any valid OpenAI model name that can be used via the
         chat completion API.
 
@@ -208,8 +211,8 @@ async def get_completion(llm: LLMConfig, prompt: str | None = N
         assert prompt is not None
         messages = [{"role": "user", "content": prompt}]
 
-    response = await client.chat.completions.create(
-        model=model,
+    response = await llm.client.chat.completions.create(
+        model=llm.model,
         messages=messages,
         temperature=0.25,
         max_tokens=1000,
diff --git a/src/starburst/helper.py b/src/starburst/helper.py
index 10c3d65..956b5f6 100644
--- a/src/starburst/helper.py
+++ b/src/starburst/helper.py
@@ -52,8 +52,7 @@ async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tup
     '''
     task_main = asyncio.create_task(
         get_completion(
-            client=llm.client,
-            model=llm.model,
+            llm,
             prompt=prompts.make_main_prompt(assignment, topics),
         )
     )
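
For reviewers, here is a minimal usage sketch of the refactored call convention: `get_completion()` now takes the `LLMConfig` dataclass directly instead of separate `client` and `model` arguments. The import path, API key, and model name below are placeholder assumptions for illustration only; in the application itself, `LLMConfig` is normally built by `_get_llm()` and injected through the `@with_llm()` decorator rather than constructed by hand.

```python
import asyncio

# Assumed import path (src/gened/openai.py importable as gened.openai);
# the key and model values are placeholders, not part of this diff.
from gened.openai import LLMConfig, get_completion

# LLMConfig is a dataclass; its _client is created lazily on first use.
llm = LLMConfig(api_key="sk-placeholder", model="gpt-4o-mini")

# Callers pass the single LLMConfig object; get_completion() reads
# llm.client and llm.model internally (see the @@ -208 hunk above).
# A real API key is needed for the request to actually succeed.
response, response_txt = asyncio.run(
    get_completion(llm, prompt="Please write 'OK'")
)
print(response_txt)
```

This mirrors the updated `test_llm()` route in `class_config.py`, which now makes the same one-argument call.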