Skip to content

Commit

Permalink
refine nl2sql tiny (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
aero-xi authored Sep 14, 2024
1 parent dd9a192 commit e1dd9d6
Showing 1 changed file with 31 additions and 58 deletions.
89 changes: 31 additions & 58 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def retrieve_with_metadata(
query_bundle.query_str = self._limit_check(query_bundle.query_str)
logger.info(f"Limited SQL query: {query_bundle.query_str}")

# set timeout to 10s
# set timeout to 10s
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(10) # start
Expand Down Expand Up @@ -367,22 +366,9 @@ def retrieve_with_metadata(
metadata,
) = self._sql_retriever.retrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":

# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
# metadata,
# ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
# logger.info(
# f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
# )
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":

# new_sql_query_str = self._sql_query_modification(sql_query_str)
Expand All @@ -399,31 +385,38 @@ def retrieve_with_metadata(
logger.info(f"async error info: {e}\n")

new_sql_query_str = self._sql_query_modification(sql_query_str)
(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# err_node = TextNode(text=f"Error: {e!s}")
# logger.info(f"async error_node info: {err_node}\n")
# retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)]
# metadata = {}
# else:
# raise

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables
# If a table was found, a new sql_query has been generated; execute it
if new_sql_query_str != sql_query_str:
(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# No table was found, so the new and old sql_query are identical; skip executing it through _sql_retriever and build retrieved_nodes directly
else:
logger.info(f"[{new_sql_query_str}] is not even a SQL")
retrieved_nodes = [
NodeWithScore(
node=TextNode(
text=new_sql_query_str,
metadata={
"query_code_instruction": new_sql_query_str,
"generated_query_code_instruction": sql_query_str,
"query_output": "",
"invalid_flag": 1,
},
),
score=1.0,
),
]
metadata = {}

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
Expand Down Expand Up @@ -466,7 +459,6 @@ async def aretrieve_with_metadata(
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
Expand Down Expand Up @@ -522,15 +514,6 @@ async def aretrieve_with_metadata(
),
]
metadata = {}
# err_node = TextNode(text=f"Error: {e!s}")
# logger.info(f"async error_node info: {err_node}\n")
# retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)]
# metadata = {}
# else:
# raise
# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
Expand All @@ -545,13 +528,6 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection.append(table)
return table_collection

def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection = list()
for table in table_list:
if table.lower() in sql_query.lower():
table_collection.append(table)
return table_collection

def _sql_query_modification(self, sql_query_str):
table_pattern = r"FROM\s+(\w+)"
match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL)
Expand All @@ -563,9 +539,6 @@ def _sql_query_modification(self, sql_query_str):
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
logger.info("No table is matched")
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
logger.info("No table is matched")

return new_sql_query_str

Expand Down

0 comments on commit e1dd9d6

Please sign in to comment.