From 87cdb72e4a1e910bdf8644ed0fdd5e62059b668e Mon Sep 17 00:00:00 2001 From: Mark Liffiton Date: Sat, 14 Dec 2024 01:05:59 -0600 Subject: [PATCH] Further refactoring of llm; start to add Gemini support; add migration script for earlier change. --- src/codehelp/helper.py | 33 ++--- src/codehelp/prompts.py | 11 +- src/codehelp/tutor.py | 24 ++-- src/gened/admin.py | 2 +- src/gened/class_config.py | 6 +- src/gened/{openai.py => llm.py} | 133 ++++++------------ .../20241213--rename-openai_key.sql | 10 ++ src/gened/openai_client.py | 73 ++++++++++ src/starburst/helper.py | 11 +- 9 files changed, 162 insertions(+), 141 deletions(-) rename src/gened/{openai.py => llm.py} (60%) create mode 100644 src/gened/migrations/20241213--rename-openai_key.sql create mode 100644 src/gened/openai_client.py diff --git a/src/codehelp/helper.py b/src/codehelp/helper.py index ad9537c..7888c99 100644 --- a/src/codehelp/helper.py +++ b/src/codehelp/helper.py @@ -27,7 +27,7 @@ ) from gened.classes import switch_class from gened.db import get_db -from gened.openai import LLMConfig, get_completion, with_llm +from gened.llm import LLM, with_llm from gened.queries import get_history, get_query from gened.testing.mocks import mock_async_completion @@ -49,7 +49,7 @@ @login_required @class_enabled_required @with_llm(spend_token=False) # get information on the selected LLM, tokens remaining -def help_form(llm: LLMConfig, query_id: int | None = None, class_id: int | None = None, ctx_name: str | None = None) -> str | Response: +def help_form(llm: LLM, query_id: int | None = None, class_id: int | None = None, ctx_name: str | None = None) -> str | Response: db = get_db() auth = get_auth() @@ -122,7 +122,7 @@ def help_view(query_id: int) -> str | Response: return render_template("help_view.html", query=query_row, responses=responses, history=history, topics=topics) -async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]: +async def run_query_prompts(llm: LLM, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]: ''' Run the given query against the coding help system of prompts. Returns a tuple containing: @@ -133,14 +133,12 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: # Launch the "sufficient detail" check concurrently with the main prompt to save time task_main = asyncio.create_task( - get_completion( - llm, + llm.get_completion( messages=prompts.make_main_prompt(code, error, issue, context_str), ) ) task_sufficient = asyncio.create_task( - get_completion( - llm, + llm.get_completion( messages=prompts.make_sufficient_prompt(code, error, issue, context_str), ) ) @@ -155,7 +153,7 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: if "```" in response_txt or "should look like" in response_txt or "should look something like" in response_txt: # That's probably too much code. Let's clean it up... 
cleanup_prompt = prompts.make_cleanup_prompt(response_text=response_txt) - cleanup_response, cleanup_response_txt = await get_completion(llm, prompt=cleanup_prompt) + cleanup_response, cleanup_response_txt = await llm.get_completion(prompt=cleanup_prompt) responses.append(cleanup_response) response_txt = cleanup_response_txt @@ -174,7 +172,7 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: return responses, {'insufficient': response_sufficient_txt, 'main': response_txt} -def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> int: +def run_query(llm: LLM, context: ContextConfig | None, code: str, error: str, issue: str) -> int: query_id = record_query(context, code, error, issue) responses, texts = asyncio.run(run_query_prompts(llm, context, code, error, issue)) @@ -223,7 +221,7 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[ @login_required @class_enabled_required @with_llm(spend_token=True) -def help_request(llm: LLMConfig) -> Response: +def help_request(llm: LLM) -> Response: if 'context' in request.form: context = get_context_by_name(request.form['context']) if context is None: @@ -244,8 +242,8 @@ def help_request(llm: LLMConfig) -> Response: @bp.route("/load_test", methods=["POST"]) @admin_required -@with_llm(use_system_key=True) # get a populated LLMConfig; not actually used (API is mocked) -def load_test(llm: LLMConfig) -> Response: +@with_llm(use_system_key=True) # get a populated LLM; not actually used (API is mocked) +def load_test(llm: LLM) -> Response: # Require that we're logged in as the load_test admin user auth = get_auth() if auth.user is None or auth.user.display_name != 'load_test': @@ -283,7 +281,7 @@ def post_helpful() -> str: @login_required @tester_required @with_llm(spend_token=False) -def get_topics_html(llm: LLMConfig, query_id: int) -> str: +def get_topics_html(llm: LLM, query_id: int) -> str: topics = get_topics(llm, query_id) if not topics: return render_template("topics_fragment.html", error=True) @@ -295,12 +293,12 @@ def get_topics_html(llm: LLMConfig, query_id: int) -> str: @login_required @tester_required @with_llm(spend_token=False) -def get_topics_raw(llm: LLMConfig, query_id: int) -> list[str]: +def get_topics_raw(llm: LLM, query_id: int) -> list[str]: topics = get_topics(llm, query_id) return topics -def get_topics(llm: LLMConfig, query_id: int) -> list[str]: +def get_topics(llm: LLM, query_id: int) -> list[str]: query_row, responses = get_query(query_id) if not query_row or not responses or 'main' not in responses: @@ -316,10 +314,7 @@ def get_topics(llm: LLMConfig, query_id: int) -> list[str]: responses['main'] ) - response, response_txt = asyncio.run(get_completion( - llm, - messages=messages, - )) + response, response_txt = asyncio.run(llm.get_completion(messages=messages)) # Verify it is actually JSON # May be "Error (..." if an API error occurs, or every now and then may get "Here is the JSON: ..." or similar. 
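The helper.py changes above are representative of the whole refactor: call sites no longer import a module-level get_completion() and pass an LLMConfig into it; instead they receive an LLM object from the with_llm decorator and call its get_completion() method. A minimal sketch of the resulting pattern, mirroring the concurrent calls in run_query_prompts() (ask_two_things is a hypothetical helper written for illustration, not part of this patch):

import asyncio

from gened.llm import LLM


async def ask_two_things(llm: LLM) -> tuple[str, str]:
    # Launch both completions concurrently, as run_query_prompts() does above.
    task_a = asyncio.create_task(llm.get_completion(prompt="Explain this error."))
    task_b = asyncio.create_task(llm.get_completion(prompt="Is the issue description sufficient?"))
    # Each get_completion() call resolves to (response_object, response_text).
    (_, text_a), (_, text_b) = await asyncio.gather(task_a, task_b)
    return text_a, text_b
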
diff --git a/src/codehelp/prompts.py b/src/codehelp/prompts.py index b8443c0..1d31a0e 100644 --- a/src/codehelp/prompts.py +++ b/src/codehelp/prompts.py @@ -4,7 +4,8 @@ from jinja2 import Environment -from openai.types.chat import ChatCompletionMessageParam + +from gened.llm import ChatMessage jinja_env = Environment( # noqa: S701 - not worried about XSS in LLM prompts trim_blocks=True, @@ -66,7 +67,7 @@ """) -def make_main_prompt(code: str, error: str, issue: str, context: str | None = None) -> list[ChatCompletionMessageParam]: +def make_main_prompt(code: str, error: str, issue: str, context: str | None = None) -> list[ChatMessage]: error = error.rstrip() issue = issue.rstrip() if error and not issue: @@ -89,7 +90,7 @@ def make_main_prompt(code: str, error: str, issue: str, context: str | None = No """) -def make_sufficient_prompt(code: str, error: str, issue: str, context: str | None) -> list[ChatCompletionMessageParam]: +def make_sufficient_prompt(code: str, error: str, issue: str, context: str | None) -> list[ChatMessage]: error = error.rstrip() issue = issue.rstrip() if error and not issue: @@ -112,9 +113,9 @@ def make_cleanup_prompt(response_text: str) -> str: """ -def make_topics_prompt(code: str, error: str, issue: str, context: str | None, response_text: str) -> list[ChatCompletionMessageParam]: +def make_topics_prompt(code: str, error: str, issue: str, context: str | None, response_text: str) -> list[ChatMessage]: sys_job = "to respond to a student's query as a helpful expert teacher" - messages : list[ChatCompletionMessageParam] = [ + messages : list[ChatMessage] = [ {'role': 'system', 'content': common_template_sys1.render(job=sys_job, code=code, error=error, issue=issue, context=context)}, {'role': 'user', 'content': common_template_user.render(code=code, error=error, issue=issue)}, {'role': 'assistant', 'content': response_text}, diff --git a/src/codehelp/tutor.py b/src/codehelp/tutor.py index 70abb8d..3e429fd 100644 --- a/src/codehelp/tutor.py +++ b/src/codehelp/tutor.py @@ -15,7 +15,6 @@ request, url_for, ) -from openai.types.chat import ChatCompletionMessageParam from werkzeug.wrappers.response import Response from gened.admin import bp as bp_admin @@ -24,7 +23,7 @@ from gened.classes import switch_class from gened.db import get_db from gened.experiments import experiment_required -from gened.openai import LLMConfig, get_completion, with_llm +from gened.llm import LLM, ChatMessage, with_llm from gened.queries import get_query from . 
import prompts @@ -93,7 +92,7 @@ def tutor_form(class_id: int | None = None, ctx_name: str | None = None) -> str @bp.route("/chat/create", methods=["POST"]) @with_llm() -def start_chat(llm: LLMConfig) -> Response: +def start_chat(llm: LLM) -> Response: topic = request.form['topic'] if 'context' in request.form: @@ -113,7 +112,7 @@ def start_chat(llm: LLMConfig) -> Response: @bp.route("/chat/create_from_query", methods=["POST"]) @with_llm() -def start_chat_from_query(llm: LLMConfig) -> Response: +def start_chat_from_query(llm: LLM) -> Response: topic = request.form['topic'] # build context from the specified query @@ -176,7 +175,7 @@ def get_chat_history(limit: int = 10) -> list[Row]: return history -def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str, str]: +def get_chat(chat_id: int) -> tuple[list[ChatMessage], str, str, str]: db = get_db() auth = get_auth() @@ -210,7 +209,7 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str, return chat, topic, context_name, context_string -def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tuple[dict[str, str], str]: +def get_response(llm: LLM, chat: list[ChatMessage]) -> tuple[dict[str, str], str]: ''' Get a new 'assistant' completion for the specified chat. Parameters: @@ -221,15 +220,12 @@ def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tupl 1) A response object from the OpenAI completion (to be stored in the database). 2) The response text. ''' - response, text = asyncio.run(get_completion( - llm, - messages=chat, - )) + response, text = asyncio.run(llm.get_completion(messages=chat)) return response, text -def save_chat(chat_id: int, chat: list[ChatCompletionMessageParam]) -> None: +def save_chat(chat_id: int, chat: list[ChatMessage]) -> None: db = get_db() db.execute( "UPDATE chats SET chat_json=? 
WHERE id=?", @@ -238,7 +234,7 @@ def save_chat(chat_id: int, chat: list[ChatCompletionMessageParam]) -> None: db.commit() -def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> None: +def run_chat_round(llm: LLM, chat_id: int, message: str|None = None) -> None: # Get the specified chat try: chat, topic, context_name, context_string = get_chat(chat_id) @@ -256,7 +252,7 @@ def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> No # Get a response (completion) from the API using an expanded version of the chat messages # Insert a system prompt beforehand and an internal monologue after to guide the assistant - expanded_chat : list[ChatCompletionMessageParam] = [ + expanded_chat : list[ChatMessage] = [ {'role': 'system', 'content': prompts.make_chat_sys_prompt(topic, context_string)}, *chat, # chat is a list; expand it here with * {'role': 'assistant', 'content': prompts.tutor_monologue}, @@ -274,7 +270,7 @@ def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> No @bp.route("/message", methods=["POST"]) @with_llm() -def new_message(llm: LLMConfig) -> Response: +def new_message(llm: LLM) -> Response: chat_id = int(request.form["id"]) new_msg = request.form["message"] diff --git a/src/gened/admin.py b/src/gened/admin.py index 97c391a..102293a 100644 --- a/src/gened/admin.py +++ b/src/gened/admin.py @@ -28,7 +28,7 @@ from .auth import admin_required from .csv import csv_response from .db import backup_db, get_db -from .openai import get_models +from .llm import get_models bp = Blueprint('admin', __name__, url_prefix="/admin", template_folder='templates') diff --git a/src/gened/class_config.py b/src/gened/class_config.py index 7db2daa..e9032fe 100644 --- a/src/gened/class_config.py +++ b/src/gened/class_config.py @@ -14,7 +14,7 @@ from .auth import get_auth_class, instructor_required from .db import get_db -from .openai import LLMConfig, get_completion, get_models, with_llm +from .llm import LLM, get_models, with_llm from .tz import date_is_past bp = Blueprint('class_config', __name__, url_prefix="/instructor/config", template_folder='templates') @@ -71,8 +71,8 @@ def config_form() -> str: @bp.route("/test_llm") @with_llm() -def test_llm(llm: LLMConfig) -> str: - response, response_txt = asyncio.run(get_completion(llm, prompt="Please write 'OK'")) +def test_llm(llm: LLM) -> str: + response, response_txt = asyncio.run(llm.get_completion(prompt="Please write 'OK'")) if 'error' in response: return f"Error:
{response_txt}" diff --git a/src/gened/openai.py b/src/gened/llm.py similarity index 60% rename from src/gened/openai.py rename to src/gened/llm.py index 5981d2c..21d6487 100644 --- a/src/gened/openai.py +++ b/src/gened/llm.py @@ -6,50 +6,56 @@ from dataclasses import dataclass, field from functools import wraps from sqlite3 import Row -from typing import ParamSpec, TypeAlias, TypeVar +from typing import Literal, ParamSpec, TypeAlias, TypeVar -import openai from flask import current_app, flash, render_template -from openai import AsyncOpenAI -from openai.types.chat import ChatCompletionMessageParam from .auth import get_auth from .db import get_db +from .openai_client import OpenAIChatMessage, OpenAIClient +LLMProvider: TypeAlias = Literal['google', 'openai'] +ChatMessage: TypeAlias = OpenAIChatMessage -class ClassDisabledError(Exception): - pass - - -class NoKeyFoundError(Exception): - pass - -class NoTokensError(Exception): - pass - - -LLMClient: TypeAlias = AsyncOpenAI +def _get_client(provider: LLMProvider, model: str, api_key: str) -> OpenAIClient: + """ Return a configured OpenAI client object (using OpenAI-compatible + endpoints for other providers) """ + match provider: + case 'google': + # https://ai.google.dev/gemini-api/docs/openai + return OpenAIClient(model, api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/") + case 'openai': + return OpenAIClient(model, api_key) @dataclass -class LLMConfig: - api_key: str +class LLM: + provider: LLMProvider model: str + api_key: str tokens_remaining: int | None = None # None if current user is not using tokens - _client: LLMClient | None = field(default=None, init=False, repr=False) + _client: OpenAIClient | None = field(default=None, init=False, repr=False) # Instantiated only when needed - @property - def client(self) -> LLMClient: - """ Lazy load the LLM client object only when requested. """ + async def get_completion(self, prompt: str | None = None, messages: list[OpenAIChatMessage] | None = None) -> tuple[dict[str, str], str]: + """ Lazily instantiate the LLM client object only when used. """ if self._client is None: - self._client = AsyncOpenAI(api_key=self.api_key) - return self._client + self._client = _get_client(self.provider, self.model, self.api_key) + return await self._client.get_completion(prompt, messages) + + +class ClassDisabledError(Exception): + pass + +class NoKeyFoundError(Exception): + pass +class NoTokensError(Exception): + pass -def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLMConfig: - ''' Get model details and an initialized LLM API client based on the - arguments and the current user and class. +def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLM: + ''' Get an LLM object configured based on the arguments and the current + context (user and class). Procedure, depending on arguments, user, and class: 1) If use_system_key is True, the system API key is always used with no checks. @@ -64,19 +70,20 @@ def _get_llm(*, use_system_key: bool, spend_token: bool) -> LLMConfig: Otherwise, their token count is decremented. Returns: - LLMConfig with an API client and model name. + LLM object. Raises various exceptions in cases where a key and model are not available. ''' db = get_db() - def make_system_client(tokens_remaining: int | None = None) -> LLMConfig: + def make_system_client(tokens_remaining: int | None = None) -> LLM: """ Factory function to initialize a default client (using the system key) only if/when needed. 
""" system_key = current_app.config["OPENAI_API_KEY"] system_model = current_app.config["SYSTEM_MODEL"] - return LLMConfig( + return LLM( + provider='openai', api_key=system_key, model=system_model, tokens_remaining=tokens_remaining, @@ -113,9 +120,9 @@ def make_system_client(tokens_remaining: int | None = None) -> LLMConfig: if not class_row['llm_api_key']: raise NoKeyFoundError - api_key = class_row['llm_api_key'] - return LLMConfig( - api_key=api_key, + return LLM( + provider='openai', + api_key=class_row['llm_api_key'], model=class_row['model'], ) @@ -194,63 +201,3 @@ def get_models() -> list[Row]: db = get_db() models = db.execute("SELECT * FROM models WHERE active ORDER BY id ASC").fetchall() return models - - -async def get_completion(llm: LLMConfig, prompt: str | None = None, messages: list[ChatCompletionMessageParam] | None = None) -> tuple[dict[str, str], str]: - ''' - model can be any valid OpenAI model name that can be used via the chat completion API. - - Returns: - - A tuple containing: - - An OpenAI response object - - The response text (stripped) - ''' - common_error_text = "Error ({error_type}). Something went wrong with this query. The error has been logged, and we'll work on it. For now, please try again." - try: - if messages is None: - assert prompt is not None - messages = [{"role": "user", "content": prompt}] - - response = await llm.client.chat.completions.create( - model=llm.model, - messages=messages, - temperature=0.25, - max_tokens=1000, - ) - - choice = response.choices[0] - response_txt = choice.message.content or "" - - if choice.finish_reason == "length": # "length" if max_tokens reached - response_txt += "\n\n[error: maximum length exceeded]" - - return response.model_dump(), response_txt.strip() - - except openai.APITimeoutError as e: - err_str = str(e) - response_txt = "Error (APITimeoutError). The system timed out producing the response. Please try again." - current_app.logger.error(f"OpenAI Timeout: {e}") - except openai.RateLimitError as e: - err_str = str(e) - if "exceeded your current quota" in err_str: - response_txt = "Error (RateLimitError). The API key for this class has exceeded its current quota (https://platform.openai.com/docs/guides/rate-limits/usage-tiers). The instructor should check their API plan and billing details. Possibly the key is in the free tier, which does not cover the models used here." - else: - response_txt = "Error (RateLimitError). The system is receiving too many requests right now. Please try again in one minute." - current_app.logger.error(f"OpenAI RateLimitError: {e}") - except openai.AuthenticationError as e: - err_str = str(e) - response_txt = "Error (AuthenticationError). The API key set by the instructor for this class is invalid. The instructor needs to provide a valid API key for this application to work." - current_app.logger.error(f"OpenAI AuthenticationError: {e}") - except openai.BadRequestError as e: - err_str = str(e) - if "maximum context length" in err_str: - response_txt = "Error (BadRequestError). Your query is too long for the model to process. Please reduce the length of your input." 
- else: - response_txt = common_error_text.format(error_type='BadRequestError') - current_app.logger.error(f"OpenAI BadRequestError: {e}") - except openai.APIError as e: - err_str = str(e) - response_txt = common_error_text.format(error_type='APIError') - current_app.logger.error(f"Exception (OpenAI {type(e).__name__}, but I don't handle that specifically yet): {e}") - - return {'error': err_str}, response_txt diff --git a/src/gened/migrations/20241213--rename-openai_key.sql b/src/gened/migrations/20241213--rename-openai_key.sql new file mode 100644 index 0000000..5492c7f --- /dev/null +++ b/src/gened/migrations/20241213--rename-openai_key.sql @@ -0,0 +1,10 @@ +-- SPDX-FileCopyrightText: 2024 Mark Liffiton +-- +-- SPDX-License-Identifier: AGPL-3.0-only + +BEGIN; + +ALTER TABLE consumers RENAME COLUMN openai_key TO llm_api_key; +ALTER TABLE classes_user RENAME COLUMN openai_key TO llm_api_key; + +COMMIT; diff --git a/src/gened/openai_client.py b/src/gened/openai_client.py new file mode 100644 index 0000000..d92f513 --- /dev/null +++ b/src/gened/openai_client.py @@ -0,0 +1,73 @@ +from typing import TypeAlias + +import openai +from flask import current_app + +OpenAIChatMessage: TypeAlias = openai.types.chat.ChatCompletionMessageParam + + +class OpenAIClient: + def __init__(self, model: str, api_key: str, *, base_url: str | None = None): + if base_url: + self._client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) + else: + self._client = openai.AsyncOpenAI(api_key=api_key) + self._model = model + + async def get_completion(self, prompt: str | None = None, messages: list[OpenAIChatMessage] | None = None) -> tuple[dict[str, str], str]: + ''' + Returns: + - A tuple containing: + - An OpenAI response object + - The response text (stripped) + ''' + common_error_text = "Error ({error_type}). Something went wrong with this query. The error has been logged, and we'll work on it. For now, please try again." + try: + if messages is None: + assert prompt is not None + messages = [{"role": "user", "content": prompt}] + + response = await self._client.chat.completions.create( + model=self._model, + messages=messages, + temperature=0.25, + max_tokens=1000, + ) + + choice = response.choices[0] + response_txt = choice.message.content or "" + + if choice.finish_reason == "length": # "length" if max_tokens reached + response_txt += "\n\n[error: maximum length exceeded]" + + return response.model_dump(), response_txt.strip() + + except openai.APITimeoutError as e: + err_str = str(e) + response_txt = "Error (APITimeoutError). The system timed out producing the response. Please try again." + current_app.logger.error(f"OpenAI Timeout: {e}") + except openai.RateLimitError as e: + err_str = str(e) + if "exceeded your current quota" in err_str: + response_txt = "Error (RateLimitError). The API key for this class has exceeded its current quota (https://platform.openai.com/docs/guides/rate-limits/usage-tiers). The instructor should check their API plan and billing details. Possibly the key is in the free tier, which does not cover the models used here." + else: + response_txt = "Error (RateLimitError). The system is receiving too many requests right now. Please try again in one minute." + current_app.logger.error(f"OpenAI RateLimitError: {e}") + except openai.AuthenticationError as e: + err_str = str(e) + response_txt = "Error (AuthenticationError). The API key set by the instructor for this class is invalid. The instructor needs to provide a valid API key for this application to work." 
+ current_app.logger.error(f"OpenAI AuthenticationError: {e}") + except openai.BadRequestError as e: + err_str = str(e) + if "maximum context length" in err_str: + response_txt = "Error (BadRequestError). Your query is too long for the model to process. Please reduce the length of your input." + else: + response_txt = common_error_text.format(error_type='BadRequestError') + current_app.logger.error(f"OpenAI BadRequestError: {e}") + except openai.APIError as e: + err_str = str(e) + response_txt = common_error_text.format(error_type='APIError') + current_app.logger.error(f"Exception (OpenAI {type(e).__name__}, but I don't handle that specifically yet): {e}") + + return {'error': err_str}, response_txt + diff --git a/src/starburst/helper.py b/src/starburst/helper.py index 956b5f6..9bc9af0 100644 --- a/src/starburst/helper.py +++ b/src/starburst/helper.py @@ -10,7 +10,7 @@ from gened.auth import class_enabled_required, get_auth, login_required from gened.db import get_db -from gened.openai import LLMConfig, get_completion, with_llm +from gened.llm import LLM, with_llm from gened.queries import get_history, get_query from . import prompts @@ -43,7 +43,7 @@ def help_view(query_id: int) -> str: return render_template("help_view.html", query=query_row, responses=responses, history=history) -async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tuple[list[dict[str, str]], dict[str, str]]: +async def run_query_prompts(llm: LLM, assignment: str, topics: str) -> tuple[list[dict[str, str]], dict[str, str]]: ''' Run the given query against the coding help system of prompts. Returns a tuple containing: @@ -51,8 +51,7 @@ async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tup 2) A dictionary of response text, potentially including the key 'main'. ''' task_main = asyncio.create_task( - get_completion( - llm, + llm.get_completion( prompt=prompts.make_main_prompt(assignment, topics), ) ) @@ -70,7 +69,7 @@ async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tup return responses, {'main': response_txt} -def run_query(llm: LLMConfig, assignment: str, topics: str) -> int: +def run_query(llm: LLM, assignment: str, topics: str) -> int: query_id = record_query(assignment, topics) responses, texts = asyncio.run(run_query_prompts(llm, assignment, topics)) @@ -110,7 +109,7 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[ @login_required @class_enabled_required @with_llm(use_system_key=True) -def help_request(llm: LLMConfig) -> Response: +def help_request(llm: LLM) -> Response: assignment = request.form["assignment"] topics = request.form["topics"]