forked from defog-ai/sql-eval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
query_generator.py
62 lines (56 loc) · 2.36 KB
/
query_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import psycopg2
class QueryGenerator:
    """
    Base class for query generators.

    To customize a query generator, implement/override the following methods:
        __init__: initialize the question-specific parameters (e.g. credentials
            for the database).
        generate_query: implement your query generation logic given a
            question. Add your secret sauce here!

    The following method is implemented here, as it is common across all
    query generators:
        exec_query: executes a query (typically the one produced by
            generate_query); only postgres is supported for now. It has an
            implicit dependency on self.db_type and self.db_creds (and
            optionally self.verbose), which subclasses are expected to set in
            __init__.
    """

    def __init__(self, **kwargs):
        # The base class stores nothing; subclasses set db_type, db_creds,
        # verbose and any generator-specific state.
        pass

    def generate_query(
        self,
        question: str,
        instructions: str,
        k_shot_prompt: str,
        glossary: str,
        table_metadata_string: str,
        prev_invalid_sql: str,
        prev_error_msg: str,
    ) -> dict:
        # Generate a query given a question, instructions and k-shot prompt.
        # Any hard-coded logic, prompt engineering, table pruning, API calls,
        # etc. should be completely contained within this function.
        # Add try/except blocks so that errors are reported through the
        # returned dict (e.g. an empty string for `query`, the message in
        # `err`) rather than raised.
        # These are the keys that you should store in the returned dict:
        #   query: the generated query
        #   reason: the reason for the query
        #   err: the error message, if any
        #   any other fields you might want to track (e.g. tokens used,
        #   latency, etc.)
        # The base implementation is a stub and returns None.
        pass

    def exec_query(self, query: str) -> str:
        """
        Try to execute `query` against postgres and report any failure.

        Relies on attributes set by the subclass's __init__:
            self.db_type: must be "postgres".
            self.db_creds: keyword arguments passed to psycopg2.connect.
            self.verbose: optional; when truthy, errors are printed.

        Returns:
            "" if the query executed and its rows could be fetched, otherwise
            the string form of the exception that was raised.

        Raises:
            ValueError: if self.db_type is not "postgres".
        """
        if self.db_type != "postgres":
            raise ValueError("Only postgres is supported for now")
        # Reset the handles on every call. The cleanup below then never
        # touches stale handles from a previous call, and it does not fail
        # with AttributeError when connect() raises on the very first call.
        self.conn = None
        self.cur = None
        try:
            self.conn = psycopg2.connect(**self.db_creds)
            self.cur = self.conn.cursor()
            self.cur.execute(query)
            # psycopg2 raises ProgrammingError here if the statement produced
            # no result set, so such statements are reported as an error.
            _ = self.cur.fetchall()
            return ""
        except Exception as e:
            if getattr(self, "verbose", False):
                print(f"Error while executing query:\n{type(e)}, {e}")
            return str(e)
        finally:
            # The transaction is never committed; per psycopg2, closing the
            # connection discards any pending changes made by the query.
            if self.cur is not None:
                self.cur.close()
            if self.conn is not None:
                self.conn.close()