-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
data/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# SPDX-FileCopyrightText: 2024 Mark Liffiton <[email protected]> | ||
# | ||
# SPDX-License-Identifier: AGPL-3.0-only | ||
|
||
import argparse | ||
import json | ||
import sqlite3 | ||
from pathlib import Path | ||
|
||
from loaders import load_prompt, load_queries, make_prompt, setup_openai | ||
|
||
# Settings shared by every chat-completion request made in this script.
TEMPERATURE: float = 0.25  # sent as the `temperature` argument
MAX_TOKENS: int = 1000  # sent as the `max_tokens` argument
|
||
|
||
def get_db(db_path):
    """Open the SQLite database at ``db_path`` and return the connection.

    The connection is opened with ``sqlite3.PARSE_DECLTYPES`` and uses
    ``sqlite3.Row`` as its row factory, so result columns can be read by
    name (e.g. ``row['id']``).
    """
    connection = sqlite3.connect(
        db_path,
        detect_types=sqlite3.PARSE_DECLTYPES,
    )
    connection.row_factory = sqlite3.Row
    return connection
|
||
|
||
def load_data(args):
    """Load queries from a file and store a new prompt set built from them.

    Reads ``args.file_path`` with ``load_queries``, selects the prompt
    function for ``args.app`` with ``load_prompt``, then writes one
    ``prompt_set`` row and one ``prompt`` row per query to the database at
    ``args.db_path``.  The set and its prompts are committed together, so a
    failure part-way through stores nothing.

    Args:
        args: Parsed command-line arguments providing ``db_path``,
            ``file_path`` (only its ``name`` is stored) and ``app``.

    Raises:
        ValueError: If the selected prompt function has no known
            model-response field.
    """
    # Load data
    queries, headers = load_queries(args.file_path)

    # Initialize the prompt (the second return value is not needed here)
    prompt_func, _fields = load_prompt(args.app, headers)
    func_name = prompt_func.__name__

    # Resolve the field holding the model response *before* touching the
    # database, so an unsupported prompt function cannot leave an empty
    # prompt_set row behind.
    model_fields = {
        "make_sufficient_prompt": "insufficient_model",
        "make_main_prompt": "main_model",
    }
    try:
        model_field = model_fields[func_name]
    except KeyError:
        raise ValueError(
            f"No model-response field is known for prompt function {func_name!r}."
        ) from None

    db = get_db(args.db_path)
    try:
        cur = db.execute(
            "INSERT INTO prompt_set(query_src_file, prompt_func) VALUES (?, ?)",
            [args.file_path.name, func_name],
        )
        prompt_set_id = cur.lastrowid

        # Generate prompts and store them
        db.executemany(
            "INSERT INTO prompt(msgs_json, model_response, set_id) VALUES(?, ?, ?)",
            (
                [json.dumps(make_prompt(prompt_func, query)), query[model_field], prompt_set_id]
                for query in queries
            ),
        )
        db.commit()
    finally:
        # Closing without a commit discards any uncommitted rows.
        db.close()

    print(f"{len(queries)} prompts inserted, prompt set ID = {prompt_set_id}.")
|
||
|
||
def choose_prompt_set(db) -> int:
    """Interactively pick one of the stored prompt sets.

    Prints every row of the ``prompt_set`` table, then asks on stdin for an
    ID, asking again until the answer is the ID of a listed prompt set.

    Args:
        db: Open database connection whose rows can be indexed by column
            name.

    Returns:
        The ``id`` of the chosen prompt set.

    Raises:
        LookupError: If the database contains no prompt sets.
    """
    prompt_sets = db.execute("SELECT * FROM prompt_set").fetchall()
    if not prompt_sets:
        raise LookupError("No prompt sets found; run the 'load' command first.")

    print("Prompt sets:")
    valid_ids = set()
    for prompt_set in prompt_sets:
        print(f"{prompt_set['id']}: {prompt_set['created']} - {prompt_set['query_src_file']} {prompt_set['prompt_func']}")
        valid_ids.add(prompt_set['id'])

    # Keep asking until the answer names a real prompt set; an unknown ID
    # would otherwise produce a response set with no prompts behind it.
    while True:
        answer = input("Select a prompt set (by ID): ")
        try:
            prompt_set_id = int(answer)
        except ValueError:
            print(f"Not a valid ID: {answer!r}")
            continue
        if prompt_set_id in valid_ids:
            return prompt_set_id
        print(f"No prompt set with ID {prompt_set_id}.")
|
||
|
||
def gen_responses(args):
    """Generate and store a model response for every prompt in a chosen set.

    Asks the user to pick a prompt set, records a new ``response_set`` row
    for ``args.model``, then sends each stored prompt to the chat-completions
    API and saves the result in the ``response`` table.

    A failed completion does not abort the run: the error message is stored
    as that prompt's response text, with a small JSON error object in place
    of the full response.

    Args:
        args: Parsed command-line arguments providing ``db_path`` and
            ``model``.
    """
    db = get_db(args.db_path)

    prompt_set_id = choose_prompt_set(db)

    client = setup_openai()

    cur = db.execute(
        "INSERT INTO response_set(model, prompt_set_id) VALUES (?, ?)",
        [args.model, prompt_set_id],
    )
    db.commit()
    response_set_id = cur.lastrowid

    prompts = db.execute("SELECT * FROM prompt WHERE set_id=?", [prompt_set_id]).fetchall()

    for i, prompt in enumerate(prompts, start=1):
        msgs = json.loads(prompt['msgs_json'])
        try:
            response = client.chat.completions.create(
                model=args.model,
                messages=msgs,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                n=1,
            )
            text = response.choices[0].message.content
            response_json = json.dumps(response.model_dump())
        except Exception as e:  # noqa
            # There is no response object on failure, so build the stored
            # JSON from the error text rather than calling model_dump() on it.
            text = f"[An error occurred in the openai completion.]\n{e}"
            response_json = json.dumps({"error": text})

        db.execute(
            "INSERT INTO response(set_id, prompt_id, response, text) VALUES(?, ?, ?, ?)",
            [response_set_id, prompt['id'], response_json, text]
        )
        # Commit per response so finished work survives an interruption.
        db.commit()
        print(f"{i}/{len(prompts)}")
|
||
|
||
def choose_response_set(db) -> tuple[int, str]:
    """Interactively pick one of the stored response sets.

    Prints every response set together with the source file and prompt
    function of its prompt set, then asks on stdin for an ID, asking again
    until the answer is the ID of a listed response set.

    Args:
        db: Open database connection whose rows can be indexed by column
            name.

    Returns:
        A ``(response_set_id, prompt_func)`` pair, where ``prompt_func`` is
        the prompt-function name recorded for the set's prompt set.

    Raises:
        LookupError: If the database contains no response sets.
    """
    response_sets = db.execute(
        "SELECT response_set.id, response_set.created, response_set.model,"
        " prompt_set.query_src_file, prompt_set.prompt_func"
        " FROM response_set JOIN prompt_set ON response_set.prompt_set_id=prompt_set.id"
    ).fetchall()
    if not response_sets:
        raise LookupError("No response sets found; run the 'response' command first.")

    # Maps each listed response set ID to its prompt-function name.
    funcs = {}

    print("Response sets:")
    for response_set in response_sets:
        print(f"{response_set['id']}: {response_set['created']} - {response_set['query_src_file']} {response_set['prompt_func']} -> {response_set['model']}")
        funcs[response_set['id']] = response_set['prompt_func']

    # Keep asking until the answer names a listed response set.
    while True:
        answer = input("Select a response set (by ID): ")
        try:
            response_set_id = int(answer)
        except ValueError:
            print(f"Not a valid ID: {answer!r}")
            continue
        if response_set_id in funcs:
            return response_set_id, funcs[response_set_id]
        print(f"No response set with ID {response_set_id}.")
|
||
|
||
def eval_sufficient(client, model, row):
    """Score one generated response against its model answer.

    ``row`` must provide ``'text'`` (the generated response) and
    ``'model_response'`` (the expected answer).  When the expected answer is
    exactly ``"OK."`` the check is a plain substring test and no API call is
    made.  Otherwise ``model`` is asked through ``client``, in JSON mode, to
    report which points of the model answer appear in the text.

    Returns:
        ``{"OK.": bool}`` for the special case; otherwise the JSON object
        produced by the evaluating model, parsed into a Python object.
    """
    candidate = row['text']
    expected = row['model_response']

    # special case; can check with simple text processing
    if expected == "OK.":
        return {"OK.": "OK." in candidate}

    system_prompt = """\
Evaluate the given text (in <text> delimiters) by comparing it to the given model (in <model> delimiters).
For each specific point in the model, evaluate whether it is addressed or mentioned in the text. Output a JSON object with a key for each point, mapping each to true if the point is included and false otherwise."""
    user_prompt = f"<text>\n{candidate}\n</text>\n\n<model>\n{expected}\n</model>"

    completion = client.chat.completions.create(
        model=model,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        n=1,
    )
    return json.loads(completion.choices[0].message.content)
|
||
|
||
def gen_evals(args):
    """Evaluate every response in a chosen response set and print a summary.

    Asks the user to pick a response set, runs the evaluation function that
    matches the set's prompt function on each response, prints any
    evaluation containing a false result, and finally prints true/false
    totals for the ``"OK."`` key and for all other keys.

    Args:
        args: Parsed command-line arguments providing ``db_path`` and
            ``model``.

    Raises:
        ValueError: If no evaluation exists for the prompt function that
            produced the chosen response set.
    """
    db = get_db(args.db_path)

    response_set_id, prompt_func = choose_response_set(db)

    # Resolve the evaluator up front so an unsupported prompt function fails
    # with a clear message instead of an unbound-variable error in the loop.
    eval_funcs = {
        "make_sufficient_prompt": eval_sufficient,
    }
    eval_func = eval_funcs.get(prompt_func)
    if eval_func is None:
        raise ValueError(f"No evaluation is implemented for prompt function {prompt_func!r}.")

    client = setup_openai()

    rows = db.execute(
        "SELECT response.text, prompt.model_response"
        " FROM response JOIN prompt ON response.prompt_id=prompt.id"
        " WHERE response.set_id=?",
        [response_set_id],
    ).fetchall()

    evals = []

    for i, row in enumerate(rows, start=1):
        evaluation = eval_func(client, args.model, row)
        evals.append(evaluation)
        print(f"{i}/{len(rows)}")
        if False in evaluation.values():
            # Show the failing response next to its per-point results.
            print(row['text'])
            print(evaluation)

    # Split the results into the "OK." special case and everything else;
    # list.count() compares by equality, matching the per-key true/false tally.
    ok_results = [evaluation["OK."] for evaluation in evals if "OK." in evaluation]
    other_results = [
        value
        for evaluation in evals
        for key, value in evaluation.items()
        if key != "OK."
    ]
    print(f"OK.: {ok_results.count(True)} true, {ok_results.count(False)} false")
    print(f"Other: {other_results.count(True)} true, {other_results.count(False)} false")
|
||
|
||
def main() -> None:
    """Parse the command line and run the selected subcommand.

    Two positional arguments (``app`` and ``db_path``) come first, followed
    by one of the subcommands ``load``, ``response`` or ``eval``.  Each
    subcommand stores its handler as ``command_func``, which is then called
    with the parsed arguments.
    """
    cli = argparse.ArgumentParser(
        description="A tool for running queries against data from a CSV/ODS/XLSX file and evaluating a model's responses.",
    )
    cli.add_argument(
        'app', type=str,
        help='The name of the application module from which to load prompts (e.g., codehelp or starburst).',
    )
    cli.add_argument(
        'db_path', type=Path,
        help='Path to the database file storing prompts and evaluations.',
    )
    commands = cli.add_subparsers(required=True)

    load_cmd = commands.add_parser(
        'load',
        help='Load a file of queries and model responses; store a generated set of prompts in the database.',
    )
    load_cmd.add_argument('file_path', type=Path, help='Path to the file to be read.')
    load_cmd.set_defaults(command_func=load_data)

    # 'response' and 'eval' accept the same optional positional model name.
    model_arg = dict(
        type=str, nargs='?', default='gpt-3.5-turbo',
        help='(Optional. Default="gpt-3.5-turbo") The LLM to use (gpt-{3.5-turbo, 4o, etc.}).',
    )

    response_cmd = commands.add_parser('response', help='Generate a response set for a given prompt set.')
    response_cmd.add_argument('model', **model_arg)
    response_cmd.set_defaults(command_func=gen_responses)

    eval_cmd = commands.add_parser('eval', help='Evaluate a given response set.')
    eval_cmd.add_argument('model', **model_arg)
    eval_cmd.set_defaults(command_func=gen_evals)

    args = cli.parse_args()

    # run the function associated with the chosen command
    args.command_func(args)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
-- SPDX-FileCopyrightText: 2024 Mark Liffiton <[email protected]> | ||
-- | ||
-- SPDX-License-Identifier: AGPL-3.0-only | ||
|
||
-- Rebuild the schema from scratch: every table below is dropped (losing any
-- stored data) and then recreated.

-- Enforcement is disabled during the drops so table order does not matter.
PRAGMA foreign_keys = OFF;

DROP TABLE IF EXISTS prompt;
DROP TABLE IF EXISTS prompt_set;
DROP TABLE IF EXISTS response;
DROP TABLE IF EXISTS response_set;
DROP TABLE IF EXISTS eval_prompt;
DROP TABLE IF EXISTS eval;
DROP TABLE IF EXISTS eval_set;

-- Re-enable enforcement.  NOTE: in SQLite this setting belongs to the
-- connection, so other connections have to enable it themselves.
PRAGMA foreign_keys = ON;

-- One generated prompt, belonging to a prompt set.
CREATE TABLE prompt (
    id              INTEGER PRIMARY KEY AUTOINCREMENT,
    set_id          INTEGER NOT NULL,
    msgs_json       TEXT    NOT NULL,
    -- A model response for this prompt; its form depends on the type of prompt.
    model_response  TEXT    NOT NULL,
    FOREIGN KEY (set_id) REFERENCES prompt_set (id) ON DELETE CASCADE
);

-- A batch of prompts generated from one source file with one prompt function.
CREATE TABLE prompt_set (
    id              INTEGER   PRIMARY KEY AUTOINCREMENT,
    created         TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    query_src_file  TEXT      NOT NULL,
    prompt_func     TEXT      NOT NULL
);

-- One stored response to one prompt, belonging to a response set.
CREATE TABLE response (
    id         INTEGER PRIMARY KEY AUTOINCREMENT,
    set_id     INTEGER NOT NULL,
    prompt_id  INTEGER NOT NULL,
    response   TEXT    NOT NULL,  -- full json response object
    text       TEXT    NOT NULL,  -- just the text
    FOREIGN KEY (prompt_id) REFERENCES prompt (id),
    FOREIGN KEY (set_id) REFERENCES response_set (id) ON DELETE CASCADE
);

-- A batch of responses produced by one model for one prompt set.
CREATE TABLE response_set (
    id             INTEGER   PRIMARY KEY AUTOINCREMENT,
    created        TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    prompt_set_id  INTEGER   NOT NULL,
    model          TEXT      NOT NULL,
    FOREIGN KEY (prompt_set_id) REFERENCES prompt_set (id)
);

-- A system prompt used for evaluations.
CREATE TABLE eval_prompt (
    id          INTEGER   PRIMARY KEY AUTOINCREMENT,
    created     TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sys_prompt  TEXT      NOT NULL
);

-- One evaluation result, belonging to an eval set.
CREATE TABLE eval (
    id              INTEGER PRIMARY KEY AUTOINCREMENT,
    set_id          INTEGER NOT NULL,
    prompt_id       INTEGER NOT NULL,
    eval_prompt_id  INTEGER NOT NULL,
    response        TEXT    NOT NULL,  -- full json response object
    time_ms         INTEGER NOT NULL,
    FOREIGN KEY (prompt_id) REFERENCES prompt (id),
    FOREIGN KEY (eval_prompt_id) REFERENCES eval_prompt (id),
    FOREIGN KEY (set_id) REFERENCES eval_set (id) ON DELETE CASCADE
);

-- A batch of evaluations produced by one model.
CREATE TABLE eval_set (
    id       INTEGER   PRIMARY KEY AUTOINCREMENT,
    created  TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    model    TEXT      NOT NULL
);