-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
67 lines (50 loc) · 1.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import openai
import json
from query import select_from_table
from schema import get_schema
from db import create_connection
# Path handed to create_connection() when this module is run as a script.
# The leading "./" makes it relative to the current working directory.
DATABASE = "./pythonsqlite.db"
def main(conn, question):
    """Turn a natural-language question into SQL and run it.

    Reads the OpenAI API key from ``auth.json`` in the current working
    directory, asks the completion model for a SQLite query that answers
    *question* against the schema returned by ``get_schema()``, prints the
    generated query, and hands it to ``select_from_table``.

    Args:
        conn: Database connection, forwarded unchanged to
            ``select_from_table``.
        question: The user's natural-language question.

    Raises:
        FileNotFoundError: If ``auth.json`` does not exist.
        KeyError: If ``auth.json`` has no ``"api_key"`` entry.
    """
    # The API key lives in a local JSON file. JSON is UTF-8 by specification,
    # so do not rely on the platform's default encoding.
    with open("auth.json", "r", encoding="utf-8") as f:
        auth = json.load(f)
    openai.api_key = auth["api_key"]

    print(f"Question: {question}")

    # The last sentence asks the model to fall back to a SELECT that returns
    # an error string, so the reply is always meant to be an executable query.
    prompt = f"""
Given the following SQL Schema:{get_schema()}
Write a query using valid SQLite syntax to answer this question: {question}
If the query involves tables that don't exist or cannot be executed without error, write a select statement that will return only a string describing the error.
"""

    # NOTE(review): `openai.Completion` / `text-davinci-003` look like the
    # legacy completions interface - confirm they are still available with
    # the installed openai package.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        temperature=0,  # deterministic output: the same question yields the same SQL
        max_tokens=200,
    )

    # Completion text usually arrives padded with leading newlines; strip it
    # so the printed query and the statement passed on are clean.
    q = response["choices"][0]["text"].strip()
    print(f"AI-generated SQL query: \n{q}")
    print("Answer: \n")

    # NOTE(review): the model output is executed as-is and `question` is
    # untrusted input, so nothing here guarantees a read-only SELECT.
    # Confirm that select_from_table (or the connection) restricts what runs.
    select_from_table(conn, q)
if __name__ == "__main__":
    # Interactive loop: keep answering questions until the user quits.
    conn = create_connection(DATABASE)
    prompt_text = """Enter a natural language query.
Type (q/Q) to stop
Query: """
    while True:
        print(prompt_text, end="")
        try:
            question = input().strip()
        except EOFError:
            # stdin was closed (Ctrl-D or end of piped input): treat it like
            # "q" instead of crashing with a traceback.
            print()
            break
        print()
        if question in ("q", "Q"):
            break
        if not question:
            # Nothing was typed: re-prompt rather than spend an API call on
            # an empty question.
            continue
        main(conn, question=question)
    print("done")