Skip to content

Commit

Permalink
synthesizer中加入table description (#296)
Browse files Browse the repository at this point in the history
* add description to synthesizer

* make lint

* fix bug

* fix tests
  • Loading branch information
aero-xi authored Dec 5, 2024
1 parent e355446 commit 65a5a1d
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/pai_rag/app/web/ui_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
SYN_GENERAL_PROMPTS = "给定一个输入问题,根据查询代码指令以及查询结果生成最终回复生成的回复语言需要与输入问题的语言保持一致。\n\n输入问题: {query_str} \n\nSQL 或 Python 查询代码指令(可选): {query_code_instruction}\n\n 查询结果: {query_output}\n\n 最终回复: "
SYN_GENERAL_PROMPTS = "给定一个输入问题,根据查询代码指令以及查询结果生成最终回复。要求:\n----------\n1.生成的回复语言需要与输入问题的语言保持一致。\n2.生成的回复需要关注数据表信息描述中可能存在的字段单位或其他补充信息。\n----------\n输入问题: {query_str} \n数据表信息描述: {db_schema} \nSQL 或 Python 查询代码指令(可选): {query_code_instruction}\n查询结果: {query_output}\n\n最终回复: "


# WELCOME_MESSAGE = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ async def empty_response_agenerator() -> AsyncGenerator[str, None]:
yield "Empty Response"


DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = (
"Given an input question, synthesize a response in Chinese from the query results.\n"
"Query: {query_str}\n\n"
"SQL or Python Code Instructions (optional):\n{query_code_instruction}\n\n"
"Code Query Output: {query_output}\n\n"
"Response: "
)

DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate(
DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL,
"给定一个输入问题,根据数据表信息描述、查询代码指令以及查询结果生成最终回复。\n"
"要求: \n"
"1.生成的回复语言需要与输入问题的语言保持一致。\n"
"2.生成的回复需要关注数据表信息描述中可能存在的字段单位或其他补充信息。\n"
"输入问题: {query_str} \n"
"数据表信息描述: {db_schema} \n"
"SQL 或 Python 查询代码指令(可选): {query_code_instruction} \n"
"查询结果: {query_output} \n\n"
"最终回复: \n\n"
)


Expand Down Expand Up @@ -88,6 +88,7 @@ def _update_prompts(self, prompts: PromptDictType) -> None:
async def aget_response(
self,
query_str: str,
db_schema: str,
retrieved_nodes: List[NodeWithScore],
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
Expand All @@ -110,6 +111,7 @@ async def aget_response(
response = await self._llm.apredict(
self._response_synthesis_prompt,
query_str=query_str,
db_schema=db_schema,
query_code_instruction=[
n.node.metadata["query_code_instruction"] for n in retrieved_nodes
], # sql or pandas query
Expand All @@ -120,6 +122,7 @@ async def aget_response(
response = await self._llm.astream(
self._response_synthesis_prompt,
query_str=query_str,
db_schema=db_schema,
query_code_instruction=[
n.node.metadata["query_code_instruction"] for n in retrieved_nodes
],
Expand All @@ -137,6 +140,7 @@ async def aget_response(
def get_response(
self,
query_str: str,
db_schema: str,
retrieved_nodes: List[NodeWithScore],
**kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
Expand All @@ -159,6 +163,7 @@ def get_response(
response = self._llm.predict(
self._response_synthesis_prompt,
query_str=query_str,
db_schema=db_schema,
query_code_instruction=[
n.node.metadata["query_code_instruction"] for n in retrieved_nodes
], # sql or pandas query
Expand All @@ -169,6 +174,7 @@ def get_response(
response = self._llm.stream(
self._response_synthesis_prompt,
query_str=query_str,
db_schema=db_schema,
query_code_instruction=[
n.node.metadata["query_code_instruction"] for n in retrieved_nodes
],
Expand All @@ -187,6 +193,7 @@ def get_response(
def synthesize(
self,
query: QueryType,
db_schema: str,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
**response_kwargs: Any,
Expand Down Expand Up @@ -228,6 +235,7 @@ def synthesize(
) as event:
response_str = self.get_response(
query_str=query.query_str,
db_schema=db_schema,
retrieved_nodes=nodes,
**response_kwargs,
)
Expand All @@ -251,6 +259,7 @@ def synthesize(
async def asynthesize(
self,
query: QueryType,
db_schema: str,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
**response_kwargs: Any,
Expand Down Expand Up @@ -291,6 +300,7 @@ async def asynthesize(
) as event:
response_str = await self.aget_response(
query_str=query.query_str,
db_schema=db_schema,
retrieved_nodes=nodes,
**response_kwargs,
)
Expand Down
30 changes: 21 additions & 9 deletions src/pai_rag/integrations/data_analysis/data_analysis_tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Tuple

from llama_index.core.prompts import PromptTemplate
from llama_index.core.base.base_query_engine import BaseQueryEngine
Expand Down Expand Up @@ -84,30 +84,40 @@ def _get_prompt_modules(self) -> PromptMixinType:
return {}

def retrieve(self, query_bundle: QueryBundle) -> Tuple[List[NodeWithScore], str]:
    """Retrieve nodes for the query together with the schema description.

    Calls ``self._retriever._retrieve`` and normalizes its result so callers
    can always unpack two values.

    Args:
        query_bundle: The query to hand to the underlying retriever.

    Returns:
        A ``(nodes, db_schema)`` pair. When the retriever returns a tuple,
        its first two items are used; otherwise the result is treated as the
        nodes and ``db_schema`` is the empty string.
    """
    result = self._retriever._retrieve(query_bundle)
    # Builtin ``tuple`` is the idiomatic runtime check; ``typing.Tuple`` is
    # meant for annotations only.
    if isinstance(result, tuple):
        return result[0], result[1]
    # Retriever produced only nodes: no schema description is available.
    return result, ""

async def aretrieve(
    self, query_bundle: QueryBundle
) -> Tuple[List[NodeWithScore], str]:
    """Asynchronously retrieve nodes together with the schema description.

    Awaits ``self._retriever._aretrieve`` and normalizes its result so
    callers can always unpack two values.

    Args:
        query_bundle: The query to hand to the underlying retriever.

    Returns:
        A ``(nodes, db_schema)`` pair. When the retriever returns a tuple,
        its first two items are used; otherwise the result is treated as the
        nodes and ``db_schema`` is the empty string.
    """
    result = await self._retriever._aretrieve(query_bundle)
    # Builtin ``tuple`` is the idiomatic runtime check; ``typing.Tuple`` is
    # meant for annotations only.
    if isinstance(result, tuple):
        return result[0], result[1]
    # Retriever produced only nodes: no schema description is available.
    return result, ""

def synthesize(
    self,
    query_bundle: QueryBundle,
    db_schema: str,
    nodes: List[NodeWithScore],
) -> RESPONSE_TYPE:
    """Delegate response synthesis to ``self._synthesizer``.

    Args:
        query_bundle: The query, forwarded as ``query``.
        db_schema: Schema description text, forwarded unchanged.
        nodes: Retrieved nodes, forwarded unchanged.

    Returns:
        Whatever ``self._synthesizer.synthesize`` returns.
    """
    synthesizer = self._synthesizer
    response = synthesizer.synthesize(
        query=query_bundle, db_schema=db_schema, nodes=nodes
    )
    return response

async def asynthesize(
    self,
    query_bundle: QueryBundle,
    db_schema: str,
    nodes: List[NodeWithScore],
) -> RESPONSE_TYPE:
    """Delegate asynchronous response synthesis to ``self._synthesizer``.

    Args:
        query_bundle: The query, forwarded as ``query``.
        db_schema: Schema description text, forwarded unchanged.
        nodes: Retrieved nodes, forwarded unchanged.

    Returns:
        Whatever awaiting ``self._synthesizer.asynthesize`` yields.
    """
    synthesizer = self._synthesizer
    response = await synthesizer.asynthesize(
        query=query_bundle, db_schema=db_schema, nodes=nodes
    )
    return response

Expand All @@ -117,9 +127,10 @@ def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
with self.callback_manager.event(
CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
) as query_event:
nodes = self.retrieve(query_bundle)
nodes, db_schema = self.retrieve(query_bundle)
response = self._synthesizer.synthesize(
query=query_bundle,
db_schema=db_schema,
nodes=nodes,
)
query_event.on_end(payload={EventPayload.RESPONSE: response})
Expand All @@ -132,9 +143,10 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
with self.callback_manager.event(
CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
) as query_event:
nodes = await self.aretrieve(query_bundle)
nodes, db_schema = await self.aretrieve(query_bundle)
response = await self._synthesizer.asynthesize(
query=query_bundle,
db_schema=db_schema,
nodes=nodes,
)

Expand All @@ -146,10 +158,10 @@ async def astream_query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
streaming = self._synthesizer._streaming
self._synthesizer._streaming = True

nodes = await self.aretrieve(query_bundle)
nodes, db_schema = await self.aretrieve(query_bundle)

stream_response = await self._synthesizer.asynthesize(
query=query_bundle, nodes=nodes
query=query_bundle, db_schema=db_schema, nodes=nodes
)
self._synthesizer._streaming = streaming

Expand Down
31 changes: 19 additions & 12 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import (
Settings,
callback_manager_from_settings_or_context,
embed_model_from_settings_or_context,
llm_from_settings_or_context,
)
Expand Down Expand Up @@ -310,7 +309,7 @@ def inspect_db_connection(
return sql_database, tables, table_descriptions


class MyNLSQLRetriever(BaseRetriever, PromptMixin):
class MyNLSQLRetriever(PromptMixin):
"""Text-to-SQL Retriever.
Retrieves via text.
Expand Down Expand Up @@ -374,10 +373,10 @@ def __init__(
self._handle_sql_errors = handle_sql_errors
self._sql_only = sql_only
self._verbose = verbose
super().__init__(
callback_manager=callback_manager
or callback_manager_from_settings_or_context(Settings, service_context)
)
# super().__init__(
# callback_manager=callback_manager
# or callback_manager_from_settings_or_context(Settings, service_context)
# )

@classmethod
def from_config(
Expand Down Expand Up @@ -546,7 +545,11 @@ def retrieve_with_metadata(
# add query_tables into metadata
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
return retrieved_nodes, {
"sql_query": sql_query_str,
"schema_description": table_desc_str,
**metadata,
}

async def aretrieve_with_metadata(
self, str_or_query_bundle: QueryType
Expand Down Expand Up @@ -645,7 +648,11 @@ async def aretrieve_with_metadata(
# add query_tables into metadata
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
return retrieved_nodes, {
"sql_query": sql_query_str,
"schema_description": table_desc_str,
**metadata,
}

def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection = list()
Expand All @@ -672,13 +679,13 @@ def _sql_query_modification(self, query_tables: list, sql_query_str: str):

def _retrieve(self, query_bundle: QueryBundle) -> "Tuple[List[NodeWithScore], str]":
    """Retrieve nodes given query, plus the schema description.

    Args:
        query_bundle: The query passed on to ``retrieve_with_metadata``.

    Returns:
        A ``(nodes, schema_description)`` pair, where the second item is the
        ``"schema_description"`` entry of the metadata returned by
        ``retrieve_with_metadata``.

    Raises:
        KeyError: If the metadata has no ``"schema_description"`` entry.
    """
    # The return annotation is quoted so it is never evaluated at definition
    # time, whether or not ``Tuple`` is imported in this module.
    retrieved_nodes, metadata = self.retrieve_with_metadata(query_bundle)
    return retrieved_nodes, metadata["schema_description"]

async def _aretrieve(
    self, query_bundle: QueryBundle
) -> "Tuple[List[NodeWithScore], str]":
    """Async retrieve nodes given query, plus the schema description.

    Args:
        query_bundle: The query passed on to ``aretrieve_with_metadata``.

    Returns:
        A ``(nodes, schema_description)`` pair, where the second item is the
        ``"schema_description"`` entry of the metadata returned by
        ``aretrieve_with_metadata``.

    Raises:
        KeyError: If the metadata has no ``"schema_description"`` entry.
    """
    # The return annotation is quoted so it is never evaluated at definition
    # time, whether or not ``Tuple`` is imported in this module.
    retrieved_nodes, metadata = await self.aretrieve_with_metadata(query_bundle)
    return retrieved_nodes, metadata["schema_description"]

def _get_table_context(self, query_bundle: QueryBundle) -> str:
"""Get table context.
Expand Down
4 changes: 2 additions & 2 deletions tests/integrations/test_nl2pandas_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def test_data_analysis_synthesizer():
data_analysis_synthesizer = DataAnalysisSynthesizer()

res_get_response = data_analysis_synthesizer.get_response(
query_str=query, retrieved_nodes=retrieved_nodes
query_str=query, db_schema="", retrieved_nodes=retrieved_nodes
)

assert len(res_get_response) > 0

res_synthesize = data_analysis_synthesizer.synthesize(
query=query, nodes=retrieved_nodes
query=query, db_schema="", nodes=retrieved_nodes
)

assert len(res_synthesize.response) > 0
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/test_nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,6 @@ def test_nl2sql_retriever(db_connection):
tables=db_tables,
)

res = nl2sql_retriever.retrieve("找出体重大于10的宠物的数量")
res, _ = nl2sql_retriever._retrieve("找出体重大于10的宠物的数量")

assert res[0].score == 1

0 comments on commit 65a5a1d

Please sign in to comment.