-
Notifications
You must be signed in to change notification settings - Fork 62
/
anthropic_runner.py
159 lines (151 loc) · 6.83 KB
/
anthropic_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import copy
import os
from eval.eval import compare_query_results
import pandas as pd
from psycopg2.extensions import QueryCanceledError
from query_generators.anthropic import AnthropicQueryGenerator
from tqdm import tqdm
from utils.questions import prepare_questions_df
from utils.creds import db_creds_all
from utils.reporting import upload_results
def run_anthropic_eval(args):
    """Generate SQL with an Anthropic model, score it, and save the results.

    The lists ``args.questions_file``, ``args.prompt_file`` and
    ``args.output_file`` are walked in lockstep (``zip`` stops at the shortest).
    For each triple this function:

    1. builds the question rows with ``prepare_questions_df``;
    2. submits one ``AnthropicQueryGenerator.generate_query`` call per row to a
       thread pool of ``args.parallel_threads`` workers;
    3. as each generation finishes, records its output on the row and, unless
       the generator reported an error, runs ``compare_query_results`` against
       the row's expected query;
    4. writes all rows to the output CSV (sorted by db_name, query_category,
       question) and prints the fraction of correct rows;
    5. uploads the rows via ``upload_results`` when ``args.upload_url`` is set.

    Args:
        args: Namespace-like object. Attributes read here: ``questions_file``,
            ``prompt_file``, ``output_file``, ``num_questions``, ``k_shot``,
            ``db_type``, ``cot_table_alias``, ``parallel_threads``, ``model``,
            ``timeout_gen``, ``use_private_data``, ``verbose``,
            ``num_columns``, ``shuffle_metadata``, ``timeout_exec``,
            ``decimal_points`` and ``upload_url``.

    Returns:
        None. Results are written to disk and optionally uploaded.
    """
    # get params from args
    questions_file_list = args.questions_file
    prompt_file_list = args.prompt_file
    output_file_list = args.output_file
    num_questions = args.num_questions
    k_shot = args.k_shot
    db_type = args.db_type
    cot_table_alias = args.cot_table_alias

    for questions_file, prompt_file, output_file in zip(
        questions_file_list, prompt_file_list, output_file_list
    ):
        print(f"Using prompt file {prompt_file}")
        # get questions
        print("Preparing questions...")
        print(
            f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
        )
        question_query_df = prepare_questions_df(
            questions_file, db_type, num_questions, k_shot, cot_table_alias
        )
        input_rows = question_query_df.to_dict("records")
        output_rows = []

        with ThreadPoolExecutor(args.parallel_threads) as executor:
            # Map each future to its source row so completed futures can be
            # matched back in O(1) instead of a linear list.index() scan.
            future_to_row = {}
            for row in input_rows:
                # get db creds for each row's db_type; deep-copied so the
                # generator cannot mutate the shared credentials mapping
                db_creds = db_creds_all[row["db_type"]]
                qg = AnthropicQueryGenerator(
                    db_creds=copy.deepcopy(db_creds),
                    db_name=row["db_name"],
                    db_type=db_type,
                    model=args.model,
                    prompt_file=prompt_file,
                    timeout=args.timeout_gen,
                    use_public_data=not args.use_private_data,
                    verbose=args.verbose,
                )
                future = executor.submit(
                    qg.generate_query,
                    question=row["question"],
                    instructions=row["instructions"],
                    k_shot_prompt=row["k_shot_prompt"],
                    glossary=row["glossary"],
                    table_metadata_string=row["table_metadata_string"],
                    prev_invalid_sql=row["prev_invalid_sql"],
                    prev_error_msg=row["prev_error_msg"],
                    cot_instructions=row["cot_instructions"],
                    columns_to_keep=args.num_columns,
                    shuffle=args.shuffle_metadata,
                )
                future_to_row[future] = row

            total_tried = 0
            total_correct = 0
            for f in (
                pbar := tqdm(as_completed(future_to_row), total=len(future_to_row))
            ):
                total_tried += 1
                row = future_to_row[f]
                result_dict = f.result()
                query_gen = result_dict["query"]
                reason = result_dict["reason"]
                err = result_dict["err"]

                # save custom metrics (only present for some generator runs)
                if "latency_seconds" in result_dict:
                    row["latency_seconds"] = result_dict["latency_seconds"]
                if "tokens_used" in result_dict:
                    row["tokens_used"] = result_dict["tokens_used"]
                row["generated_query"] = query_gen
                row["reason"] = reason
                row["error_msg"] = err

                # save failures into relevant columns in the dataframe
                if "GENERATION ERROR" in err:
                    row["error_query_gen"] = 1
                elif "EXECUTION ERROR" in err:
                    row["error_db_exec"] = 1
                elif "TIMEOUT" in err:
                    row["timeout"] = 1
                else:
                    # Row-scoped name: do not rebind the function-level
                    # `db_type`, which later file triples still rely on.
                    row_db_type = row["db_type"]
                    # try executing the queries and compare the results if they succeed
                    try:
                        exact_match, correct = compare_query_results(
                            query_gold=row["query"],
                            query_gen=query_gen,
                            db_name=row["db_name"],
                            db_type=row_db_type,
                            db_creds=db_creds_all[row_db_type],
                            timeout=args.timeout_exec,
                            question=row["question"],
                            query_category=row["query_category"],
                            table_metadata_string=row["table_metadata_string"],
                            decimal_points=args.decimal_points,
                        )
                        row["exact_match"] = int(exact_match)
                        row["correct"] = int(correct)
                        row["error_msg"] = ""
                        if correct:
                            total_correct += 1
                    except QueryCanceledError as e:
                        row["timeout"] = 1
                        row["error_msg"] = f"QUERY EXECUTION TIMEOUT: {e}"
                    except Exception as e:
                        # Best-effort eval: any other failure is recorded on
                        # the row rather than aborting the whole run.
                        row["error_db_exec"] = 1
                        row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"

                output_rows.append(row)
                pbar.set_description(
                    f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
                )

        output_df = pd.DataFrame(output_rows)
        output_df = output_df.sort_values(by=["db_name", "query_category", "question"])

        # Create the output directory if needed. dirname() is "" when
        # output_file is a bare filename, and os.makedirs("") raises, so only
        # create when there is a directory component. exist_ok avoids a race
        # between an existence check and the creation.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        output_df.to_csv(output_file, index=False, float_format="%.2f")

        # get average rate of correct results
        avg_subset = output_df["correct"].sum() / len(output_df)
        print(f"Average correct rate: {avg_subset:.2f}")
        results = output_df.to_dict("records")

        # upload results; the prompt text is only needed for the upload, so
        # it is read only when an upload target is configured
        if args.upload_url is not None:
            with open(prompt_file, "r") as f:
                prompt = f.read()
            upload_results(
                results=results,
                url=args.upload_url,
                runner_type="anthropic",
                prompt=prompt,
                args=args,
            )