
Commit

Further refactoring of llm; start to add Gemini support; add migration script for earlier change.
liffiton committed Dec 14, 2024
1 parent 6395766 commit 87cdb72
Showing 9 changed files with 162 additions and 141 deletions.
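This commit swaps the OpenAI-specific helpers in gened.openai (LLMConfig, get_completion, get_models, with_llm) for a provider-agnostic gened.llm module (LLM, ChatMessage, get_models, with_llm), paving the way for Gemini support. Only those names appear in the diff; the sketch below of the new interface is inferred from the call sites in the changed files, and every detail beyond the names themselves is an assumption rather than code from the repository.

# Hypothetical sketch of the gened.llm surface implied by the call sites below.
# Only the names LLM, ChatMessage, with_llm, and get_models appear in this commit;
# the TypedDict/Protocol formulation and the exact signatures are assumptions.
from typing import Protocol, TypedDict


class ChatMessage(TypedDict):
    role: str      # 'system', 'user', or 'assistant'
    content: str


class LLM(Protocol):
    async def get_completion(
        self,
        *,
        prompt: str | None = None,
        messages: list[ChatMessage] | None = None,
    ) -> tuple[dict[str, str], str]:
        """Return (raw response dict for storage, extracted response text)."""
        ...

In the changed files, with_llm(spend_token=..., use_system_key=...) is still used as a view decorator that injects an LLM instance as the view's first argument, and get_models() is still imported by the admin and class-config pages; both now come from gened.llm instead of gened.openai.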
33 changes: 14 additions & 19 deletions src/codehelp/helper.py
@@ -27,7 +27,7 @@
)
from gened.classes import switch_class
from gened.db import get_db
from gened.openai import LLMConfig, get_completion, with_llm
from gened.llm import LLM, with_llm
from gened.queries import get_history, get_query
from gened.testing.mocks import mock_async_completion

@@ -49,7 +49,7 @@
@login_required
@class_enabled_required
@with_llm(spend_token=False) # get information on the selected LLM, tokens remaining
def help_form(llm: LLMConfig, query_id: int | None = None, class_id: int | None = None, ctx_name: str | None = None) -> str | Response:
def help_form(llm: LLM, query_id: int | None = None, class_id: int | None = None, ctx_name: str | None = None) -> str | Response:
db = get_db()
auth = get_auth()

@@ -122,7 +122,7 @@ def help_view(query_id: int) -> str | Response:
return render_template("help_view.html", query=query_row, responses=responses, history=history, topics=topics)


async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]:
async def run_query_prompts(llm: LLM, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]:
''' Run the given query against the coding help system of prompts.
Returns a tuple containing:
@@ -133,14 +133,12 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code:

# Launch the "sufficient detail" check concurrently with the main prompt to save time
task_main = asyncio.create_task(
get_completion(
llm,
llm.get_completion(
messages=prompts.make_main_prompt(code, error, issue, context_str),
)
)
task_sufficient = asyncio.create_task(
get_completion(
llm,
llm.get_completion(
messages=prompts.make_sufficient_prompt(code, error, issue, context_str),
)
)
@@ -155,7 +153,7 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code:
if "```" in response_txt or "should look like" in response_txt or "should look something like" in response_txt:
# That's probably too much code. Let's clean it up...
cleanup_prompt = prompts.make_cleanup_prompt(response_text=response_txt)
cleanup_response, cleanup_response_txt = await get_completion(llm, prompt=cleanup_prompt)
cleanup_response, cleanup_response_txt = await llm.get_completion(prompt=cleanup_prompt)
responses.append(cleanup_response)
response_txt = cleanup_response_txt

@@ -174,7 +172,7 @@ async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code:
return responses, {'insufficient': response_sufficient_txt, 'main': response_txt}


def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> int:
def run_query(llm: LLM, context: ContextConfig | None, code: str, error: str, issue: str) -> int:
query_id = record_query(context, code, error, issue)

responses, texts = asyncio.run(run_query_prompts(llm, context, code, error, issue))
@@ -223,7 +221,7 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[
@login_required
@class_enabled_required
@with_llm(spend_token=True)
def help_request(llm: LLMConfig) -> Response:
def help_request(llm: LLM) -> Response:
if 'context' in request.form:
context = get_context_by_name(request.form['context'])
if context is None:
@@ -244,8 +242,8 @@ def help_request(llm: LLMConfig) -> Response:

@bp.route("/load_test", methods=["POST"])
@admin_required
@with_llm(use_system_key=True) # get a populated LLMConfig; not actually used (API is mocked)
def load_test(llm: LLMConfig) -> Response:
@with_llm(use_system_key=True) # get a populated LLM; not actually used (API is mocked)
def load_test(llm: LLM) -> Response:
# Require that we're logged in as the load_test admin user
auth = get_auth()
if auth.user is None or auth.user.display_name != 'load_test':
@@ -283,7 +281,7 @@ def post_helpful() -> str:
@login_required
@tester_required
@with_llm(spend_token=False)
def get_topics_html(llm: LLMConfig, query_id: int) -> str:
def get_topics_html(llm: LLM, query_id: int) -> str:
topics = get_topics(llm, query_id)
if not topics:
return render_template("topics_fragment.html", error=True)
@@ -295,12 +293,12 @@ def get_topics_html(llm: LLMConfig, query_id: int) -> str:
@login_required
@tester_required
@with_llm(spend_token=False)
def get_topics_raw(llm: LLMConfig, query_id: int) -> list[str]:
def get_topics_raw(llm: LLM, query_id: int) -> list[str]:
topics = get_topics(llm, query_id)
return topics


def get_topics(llm: LLMConfig, query_id: int) -> list[str]:
def get_topics(llm: LLM, query_id: int) -> list[str]:
query_row, responses = get_query(query_id)

if not query_row or not responses or 'main' not in responses:
@@ -316,10 +314,7 @@ def get_topics(llm: LLMConfig, query_id: int) -> list[str]:
responses['main']
)

response, response_txt = asyncio.run(get_completion(
llm,
messages=messages,
))
response, response_txt = asyncio.run(llm.get_completion(messages=messages))

# Verify it is actually JSON
# May be "Error (..." if an API error occurs, or every now and then may get "Here is the JSON: ..." or similar.
11 changes: 6 additions & 5 deletions src/codehelp/prompts.py
@@ -4,7 +4,8 @@


from jinja2 import Environment
from openai.types.chat import ChatCompletionMessageParam

from gened.llm import ChatMessage

jinja_env = Environment( # noqa: S701 - not worried about XSS in LLM prompts
trim_blocks=True,
@@ -66,7 +67,7 @@
""")


def make_main_prompt(code: str, error: str, issue: str, context: str | None = None) -> list[ChatCompletionMessageParam]:
def make_main_prompt(code: str, error: str, issue: str, context: str | None = None) -> list[ChatMessage]:
error = error.rstrip()
issue = issue.rstrip()
if error and not issue:
@@ -89,7 +90,7 @@ def make_main_prompt(code: str, error: str, issue: str, context: str | None = No
""")


def make_sufficient_prompt(code: str, error: str, issue: str, context: str | None) -> list[ChatCompletionMessageParam]:
def make_sufficient_prompt(code: str, error: str, issue: str, context: str | None) -> list[ChatMessage]:
error = error.rstrip()
issue = issue.rstrip()
if error and not issue:
@@ -112,9 +113,9 @@ def make_cleanup_prompt(response_text: str) -> str:
"""


def make_topics_prompt(code: str, error: str, issue: str, context: str | None, response_text: str) -> list[ChatCompletionMessageParam]:
def make_topics_prompt(code: str, error: str, issue: str, context: str | None, response_text: str) -> list[ChatMessage]:
sys_job = "to respond to a student's query as a helpful expert teacher"
messages : list[ChatCompletionMessageParam] = [
messages : list[ChatMessage] = [
{'role': 'system', 'content': common_template_sys1.render(job=sys_job, code=code, error=error, issue=issue, context=context)},
{'role': 'user', 'content': common_template_user.render(code=code, error=error, issue=issue)},
{'role': 'assistant', 'content': response_text},
24 changes: 10 additions & 14 deletions src/codehelp/tutor.py
@@ -15,7 +15,6 @@
request,
url_for,
)
from openai.types.chat import ChatCompletionMessageParam
from werkzeug.wrappers.response import Response

from gened.admin import bp as bp_admin
@@ -24,7 +23,7 @@
from gened.classes import switch_class
from gened.db import get_db
from gened.experiments import experiment_required
from gened.openai import LLMConfig, get_completion, with_llm
from gened.llm import LLM, ChatMessage, with_llm
from gened.queries import get_query

from . import prompts
@@ -93,7 +92,7 @@ def tutor_form(class_id: int | None = None, ctx_name: str | None = None) -> str

@bp.route("/chat/create", methods=["POST"])
@with_llm()
def start_chat(llm: LLMConfig) -> Response:
def start_chat(llm: LLM) -> Response:
topic = request.form['topic']

if 'context' in request.form:
@@ -113,7 +112,7 @@

@bp.route("/chat/create_from_query", methods=["POST"])
@with_llm()
def start_chat_from_query(llm: LLMConfig) -> Response:
def start_chat_from_query(llm: LLM) -> Response:
topic = request.form['topic']

# build context from the specified query
@@ -176,7 +175,7 @@ def get_chat_history(limit: int = 10) -> list[Row]:
return history


def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str, str]:
def get_chat(chat_id: int) -> tuple[list[ChatMessage], str, str, str]:
db = get_db()
auth = get_auth()

@@ -210,7 +209,7 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str,
return chat, topic, context_name, context_string


def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tuple[dict[str, str], str]:
def get_response(llm: LLM, chat: list[ChatMessage]) -> tuple[dict[str, str], str]:
''' Get a new 'assistant' completion for the specified chat.
Parameters:
@@ -221,15 +220,12 @@ def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tupl
1) A response object from the OpenAI completion (to be stored in the database).
2) The response text.
'''
response, text = asyncio.run(get_completion(
llm,
messages=chat,
))
response, text = asyncio.run(llm.get_completion(messages=chat))

return response, text


def save_chat(chat_id: int, chat: list[ChatCompletionMessageParam]) -> None:
def save_chat(chat_id: int, chat: list[ChatMessage]) -> None:
db = get_db()
db.execute(
"UPDATE chats SET chat_json=? WHERE id=?",
Expand All @@ -238,7 +234,7 @@ def save_chat(chat_id: int, chat: list[ChatCompletionMessageParam]) -> None:
db.commit()


def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> None:
def run_chat_round(llm: LLM, chat_id: int, message: str|None = None) -> None:
# Get the specified chat
try:
chat, topic, context_name, context_string = get_chat(chat_id)
Expand All @@ -256,7 +252,7 @@ def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> No

# Get a response (completion) from the API using an expanded version of the chat messages
# Insert a system prompt beforehand and an internal monologue after to guide the assistant
expanded_chat : list[ChatCompletionMessageParam] = [
expanded_chat : list[ChatMessage] = [
{'role': 'system', 'content': prompts.make_chat_sys_prompt(topic, context_string)},
*chat, # chat is a list; expand it here with *
{'role': 'assistant', 'content': prompts.tutor_monologue},
@@ -274,7 +270,7 @@ def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> No

@bp.route("/message", methods=["POST"])
@with_llm()
def new_message(llm: LLMConfig) -> Response:
def new_message(llm: LLM) -> Response:
chat_id = int(request.form["id"])
new_msg = request.form["message"]

2 changes: 1 addition & 1 deletion src/gened/admin.py
@@ -28,7 +28,7 @@
from .auth import admin_required
from .csv import csv_response
from .db import backup_db, get_db
from .openai import get_models
from .llm import get_models

bp = Blueprint('admin', __name__, url_prefix="/admin", template_folder='templates')

6 changes: 3 additions & 3 deletions src/gened/class_config.py
@@ -14,7 +14,7 @@

from .auth import get_auth_class, instructor_required
from .db import get_db
from .openai import LLMConfig, get_completion, get_models, with_llm
from .llm import LLM, get_models, with_llm
from .tz import date_is_past

bp = Blueprint('class_config', __name__, url_prefix="/instructor/config", template_folder='templates')
@@ -71,8 +71,8 @@ def config_form() -> str:

@bp.route("/test_llm")
@with_llm()
def test_llm(llm: LLMConfig) -> str:
response, response_txt = asyncio.run(get_completion(llm, prompt="Please write 'OK'"))
def test_llm(llm: LLM) -> str:
response, response_txt = asyncio.run(llm.get_completion(prompt="Please write 'OK'"))

if 'error' in response:
return f"<b>Error:</b><br>{response_txt}"
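Across the five files shown, the call-site change follows one mechanical pattern: the module-level get_completion(llm, ...) helper becomes a method on the injected LLM object, with the same (response, text) return pair. An illustrative before/after, paraphrasing the diff rather than quoting any one file:

# before: module-level helper from gened.openai
response, response_txt = asyncio.run(get_completion(llm, messages=messages))

# after: method on the LLM object from gened.llm
response, response_txt = asyncio.run(llm.get_completion(messages=messages))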