Skip to content

Commit

Permalink
poc to return similarity score
Browse files Browse the repository at this point in the history
  • Loading branch information
jrpereirajr committed Aug 16, 2024
1 parent 464ece7 commit 333a68c
Showing 1 changed file with 322 additions and 0 deletions.
322 changes: 322 additions & 0 deletions python/sqlzilla/sqlzilla copy 3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
from typing import Dict, List
from sqlalchemy import create_engine
import hashlib
import pandas as pd;
import re

from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.docstore.document import Document
from langchain_community.document_loaders import DataFrameLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_iris import IRISVector

class SQLZilla:
def __init__(self, connection_string, openai_api_key):
self.log('criou')
self.openai_api_key = openai_api_key
self.schema_name = None
self.engine = create_engine(connection_string)
self.conn_wrapper = self.engine.connect()
self.connection = self.conn_wrapper.connection
self.log('connection opened')
self.context = {}
self.context["top_k"] = 3
self.tables_vector_store = None
self.example_selector = None
self.chain_model = None
self.example_prompt = None
self.create_chain_model()

def create_examples_table(self):
sql = """
CREATE TABLE IF NOT EXISTS sqlzilla.examples (
id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
prompt VARCHAR(255) NOT NULL,
query VARCHAR(255) NOT NULL,
schema_name VARCHAR(255) NOT NULL
);
"""
self.execute_query(sql)

def get_examples(self):
sql = "SELECT prompt, query FROM sqlzilla.examples WHERE schema_name = %s"
self.log('sql: ' + sql)
self.log('params: ' + str([self.schema_name]))
rows = self.execute_query(sql, [self.schema_name])
examples = [{
"input": row[0],
"query": row[1]
} for row in rows]
return examples

def add_example(self, prompt, query):
sql = "INSERT INTO sqlzilla.examples (prompt, query, schema_name) VALUES (%s, %s, %s)"
self.execute_query(sql, [prompt, query, self.schema_name])

def __del__(self):
self.log('deletou')
if not self.connection is None:
self.log('connection closed')
self.connection.close()
if not self.engine is None:
self.engine.dispose()

def log(self, msg):
import os
os.write(1, f"{msg}\n".encode())

def get_table_definitions_array(self, schema, table=None):
cursor = self.connection.cursor()

# Base query to get columns information
query = """
SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = %s
"""

# Parameters for the query
params = [schema]

# Adding optional filters
if table:
query += " AND TABLE_NAME = %s"
params.append(table)

# Execute the query
cursor.execute(query, params)

# Fetch the results
rows = cursor.fetchall()

# Process the results to generate the table definition(s)
table_definitions = {}
for row in rows:
table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row
if table_name not in table_definitions:
table_definitions[table_name] = []
table_definitions[table_name].append({
"column_name": column_name,
"column_type": column_type,
"is_nullable": is_nullable,
"column_default": column_default,
"column_key": column_key,
"extra": extra
})

primary_keys = {}

# Build the output string
result = []
for table_name, columns in table_definitions.items():
table_def = f"CREATE TABLE {schema}.{table_name} (\n"
column_definitions = []
for column in columns:
column_def = f" {column['column_name']} {column['column_type']}"
if column['is_nullable'] == "NO":
column_def += " NOT NULL"
if column['column_default'] is not None:
column_def += f" DEFAULT {column['column_default']}"
if column['extra']:
column_def += f" {column['extra']}"
column_definitions.append(column_def)
if table_name in primary_keys:
pk_def = f" PRIMARY KEY ({', '.join(primary_keys[table_name])})"
column_definitions.append(pk_def)
table_def += ",\n".join(column_definitions)
table_def += "\n);"
result.append(table_def)

return result

def get_table_definitions(self, schema, table=None):
return "\n\n".join(self.get_table_definitions_array(schema=schema, table=table))

def get_ids_from_string_array(self, array):
return [str(hashlib.md5(x.encode()).hexdigest()) for x in array]

def exists_in_db(self, collection_name, id):
schema_name = "SQLUser"

cursor = self.connection.cursor()
query = f"""
SELECT TOP 1 id
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = %s and TABLE_NAME = %s
"""
params = [schema_name, collection_name]
cursor.execute(query, params)
rows = cursor.fetchall()
if len(rows) == 0:
return False

del cursor, query, params, rows

cursor = self.connection.cursor()
query = f"""
SELECT TOP 1 id
FROM {collection_name}
WHERE id = %s
"""
params = [id]
cursor.execute(query, params)
rows = cursor.fetchall()
return len(rows) > 0

def filter_not_in_collection(self, collection_name, docs_array, ids_array):
filtered = [x for x in zip(docs_array, ids_array) if not self.exists_in_db(collection_name, x[1])]
return list(zip(*filtered)) or ([], [])

def schema_context_management(self, schema):
if self.tables_vector_store is None:
table_def = self.get_table_definitions_array(schema)
self.table_df = pd.DataFrame(data=table_def, columns=["col_def"])
self.table_df["id"] = self.table_df.index + 1
loader = DataFrameLoader(self.table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
self.tables_docs = text_splitter.split_documents(documents)
self.log('schema_name: ' + str(self.schema_name))
collection_name_tables = "sql_tables_"+self.schema_name
new_tables_docs, tables_docs_ids = self.filter_not_in_collection(
collection_name_tables,
self.tables_docs,
self.get_ids_from_string_array([x.page_content for x in self.tables_docs])
)
self.tables_docs_ids = tables_docs_ids
self.tables_vector_store = IRISVector.from_documents(
embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key),
documents = self.tables_docs,
connection=self.conn_wrapper,
collection_name=collection_name_tables,
ids=self.tables_docs_ids
)

if self.example_selector is None:
examples = self.get_examples()
collection_name_examples = "sql_samples_"+self.schema_name
new_sql_samples, sql_samples_ids = self.filter_not_in_collection(
collection_name_examples,
examples,
self.get_ids_from_string_array([x['input'] for x in examples])
)
self.example_selector = MySemanticSimilarityExampleSelector.from_examples(
new_sql_samples,
OpenAIEmbeddings(openai_api_key=self.openai_api_key),
IRISVector,
k=5,
input_keys=["input"],
connection=self.conn_wrapper,
collection_name=collection_name_examples,
ids=sql_samples_ids
)

def create_chain_model(self):
if not self.chain_model is None:
return self.chain_model

iris_sql_template = """
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL; don't apply any kind of formatting.
"""
tables_prompt_template = """
Only use the following tables:
{table_info}
"""
prompt_sql_few_shots_template = """
Below are a number of examples of questions and their corresponding SQL queries.
{examples_value}
"""
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)
self.example_prompt = example_prompt

user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string()
prompt = (
ChatPromptTemplate.from_messages([("system", iris_sql_template)])
+ ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
+ ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
+ ChatPromptTemplate.from_messages([("human", user_prompt)])
)

model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key)
output_parser = StrOutputParser()
self.chain_model = prompt | model | output_parser

def prompt(self, input):
self.context["input"] = input

relevant_tables_docs = self.tables_vector_store.similarity_search(input)
self.log('relevant_tables_docs: ' + str(relevant_tables_docs))
relevant_tables_docs_with_score = self.tables_vector_store.similarity_search_with_score(input)
self.log('relevant_tables_docs_with_score: ' + str(relevant_tables_docs_with_score))
relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
indices = self.table_df["id"].isin(relevant_tables_docs_indices)
relevant_tables_array = [x for x in self.table_df[indices]["col_def"]]
self.context["table_info"] = "\n\n".join(relevant_tables_array)

examples_value = self.example_selector.select_examples({"input": self.context["input"]})
self.log('examples_value: ' + str(examples_value))
examples_value_with_score = self.example_selector.select_examples_with_score({"input": self.context["input"]})
self.log('examples_value_with_score: ' + str(examples_value_with_score))
self.context["examples_value"] = "\n\n".join([
self.example_prompt.invoke(x).to_string() for x in examples_value
])

self.log('context: ' + str(self.context))

response = self.create_chain_model().invoke({
"top_k": self.context["top_k"],
"table_info": self.context["table_info"],
"examples_value": self.context["examples_value"],
"input": self.context["input"]
})
return response

def execute_query(self, query, params=None):
cursor = self.connection.cursor()
# Execute the query
cursor.execute(query, params)

if re.search(r"\s*SELECT\s+", query, re.IGNORECASE):
# Fetch the results
return cursor.fetchall()
elif re.search(r"\s*INSERT\s+", query, re.IGNORECASE):
self.connection.commit()
elif re.search(r"\s*UPDATE\s+", query, re.IGNORECASE):
self.connection.commit()
elif re.search(r"\s*DELETE\s+", query, re.IGNORECASE):
self.connection.commit()
return None

class MySemanticSimilarityExampleSelector(SemanticSimilarityExampleSelector):

def select_examples_with_score(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select examples based on semantic similarity.
Args:
input_variables: The input variables to use for search.
Returns:
The selected examples.
"""
# Get the docs with the highest similarity.
vectorstore_kwargs = self.vectorstore_kwargs or {}
example_docs_with_score = self.vectorstore.similarity_search_with_score(
self._example_to_text(input_variables, self.input_keys),
k=self.k,
**vectorstore_kwargs,
)
example_docs = [x[0] for x in example_docs_with_score]
scores = [x[1] for x in example_docs_with_score]
examples = self._documents_to_examples(example_docs)
return list(zip(examples, scores))

0 comments on commit 333a68c

Please sign in to comment.