diff --git a/src/pai_rag/app/web/ui_constants.py b/src/pai_rag/app/web/ui_constants.py index 886c928a..0bcb2e23 100644 --- a/src/pai_rag/app/web/ui_constants.py +++ b/src/pai_rag/app/web/ui_constants.py @@ -8,7 +8,7 @@ DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " -SYN_GENERAL_PROMPTS = "给定一个输入问题,根据查询代码指令以及查询结果生成最终回复,生成的回复语言需要与输入问题的语言保持一致。\n\n输入问题: {query_str} \n\nSQL 或 Python 查询代码指令(可选): {query_code_instruction}\n\n 查询结果: {query_output}\n\n 最终回复: " +SYN_GENERAL_PROMPTS = "给定一个输入问题,根据查询代码指令以及查询结果生成最终回复。要求:\n----------\n1.生成的回复语言需要与输入问题的语言保持一致。\n2.生成的回复需要关注数据表信息描述中可能存在的字段单位或其他补充信息。\n----------\n输入问题: {query_str} \n数据表信息描述: {db_schema} \nSQL 或 Python 查询代码指令(可选): {query_code_instruction}\n查询结果: {query_output}\n\n最终回复: " # WELCOME_MESSAGE = """ diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py b/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py index 0e201457..fb57f4b1 100644 --- a/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py +++ b/src/pai_rag/integrations/data_analysis/data_analysis_synthesizer.py @@ -35,16 +35,16 @@ async def empty_response_agenerator() -> AsyncGenerator[str, None]: yield "Empty Response" -DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = ( - "Given an input question, synthesize a response in Chinese from the query results.\n" - "Query: {query_str}\n\n" - "SQL or Python Code Instructions (optional):\n{query_code_instruction}\n\n" - "Code Query Output: {query_output}\n\n" - "Response: " -) - DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate( - DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL, + "给定一个输入问题,根据数据表信息描述、查询代码指令以及查询结果生成最终回复。\n" + "要求: \n" + "1.生成的回复语言需要与输入问题的语言保持一致。\n" + "2.生成的回复需要关注数据表信息描述中可能存在的字段单位或其他补充信息。\n" + "输入问题: {query_str} \n" + "数据表信息描述: {db_schema} \n" + "SQL 或 Python 查询代码指令(可选): {query_code_instruction} \n" + "查询结果: {query_output} \n\n" + "最终回复: \n\n" ) @@ -88,6 +88,7 @@ def _update_prompts(self, prompts: PromptDictType) -> None: async def aget_response( self, query_str: str, + db_schema: str, retrieved_nodes: List[NodeWithScore], **response_kwargs: Any, ) -> RESPONSE_TEXT_TYPE: @@ -110,6 +111,7 @@ async def aget_response( response = await self._llm.apredict( self._response_synthesis_prompt, query_str=query_str, + db_schema=db_schema, query_code_instruction=[ n.node.metadata["query_code_instruction"] for n in retrieved_nodes ], # sql or pandas query @@ -120,6 +122,7 @@ async def aget_response( response = await self._llm.astream( self._response_synthesis_prompt, query_str=query_str, + db_schema=db_schema, query_code_instruction=[ n.node.metadata["query_code_instruction"] for n in retrieved_nodes ], @@ -137,6 +140,7 @@ async def aget_response( def get_response( self, query_str: str, + db_schema: str, retrieved_nodes: List[NodeWithScore], **kwargs: Any, ) -> RESPONSE_TEXT_TYPE: @@ -159,6 +163,7 @@ def get_response( response = self._llm.predict( self._response_synthesis_prompt, query_str=query_str, + db_schema=db_schema, query_code_instruction=[ n.node.metadata["query_code_instruction"] for n in retrieved_nodes ], # sql or pandas query @@ -169,6 +174,7 @@ def get_response( response = self._llm.stream( self._response_synthesis_prompt, query_str=query_str, + db_schema=db_schema, query_code_instruction=[ n.node.metadata["query_code_instruction"] for n in retrieved_nodes ], @@ -187,6 +193,7 @@ def get_response( def synthesize( self, query: QueryType, + db_schema: str, nodes: List[NodeWithScore], additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, **response_kwargs: Any, @@ -228,6 +235,7 @@ def synthesize( ) as event: response_str = self.get_response( query_str=query.query_str, + db_schema=db_schema, retrieved_nodes=nodes, **response_kwargs, ) @@ -251,6 +259,7 @@ def synthesize( async def asynthesize( self, query: QueryType, + db_schema: str, nodes: List[NodeWithScore], additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, **response_kwargs: Any, @@ -291,6 +300,7 @@ async def asynthesize( ) as event: response_str = await self.aget_response( query_str=query.query_str, + db_schema=db_schema, retrieved_nodes=nodes, **response_kwargs, ) diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py index cd109b56..2f4d8637 100644 --- a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py +++ b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Tuple from llama_index.core.prompts import PromptTemplate from llama_index.core.base.base_query_engine import BaseQueryEngine @@ -84,30 +84,40 @@ def _get_prompt_modules(self) -> PromptMixinType: return {} def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - nodes = self._retriever.retrieve(query_bundle) - return nodes + nodes = self._retriever._retrieve(query_bundle) + if isinstance(nodes, Tuple): + return nodes[0], nodes[1] + else: + return nodes, "" async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - nodes = await self._retriever.aretrieve(query_bundle) - return nodes + nodes = await self._retriever._aretrieve(query_bundle) + if isinstance(nodes, Tuple): + return nodes[0], nodes[1] + else: + return nodes, "" def synthesize( self, query_bundle: QueryBundle, + db_schema: str, nodes: List[NodeWithScore], ) -> RESPONSE_TYPE: return self._synthesizer.synthesize( query=query_bundle, + db_schema=db_schema, nodes=nodes, ) async def asynthesize( self, query_bundle: QueryBundle, + db_schema: str, nodes: List[NodeWithScore], ) -> RESPONSE_TYPE: return await self._synthesizer.asynthesize( query=query_bundle, + db_schema=db_schema, nodes=nodes, ) @@ -117,9 +127,10 @@ def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: - nodes = self.retrieve(query_bundle) + nodes, db_schema = self.retrieve(query_bundle) response = self._synthesizer.synthesize( query=query_bundle, + db_schema=db_schema, nodes=nodes, ) query_event.on_end(payload={EventPayload.RESPONSE: response}) @@ -132,9 +143,10 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: - nodes = await self.aretrieve(query_bundle) + nodes, db_schema = await self.aretrieve(query_bundle) response = await self._synthesizer.asynthesize( query=query_bundle, + db_schema=db_schema, nodes=nodes, ) @@ -146,10 +158,10 @@ async def astream_query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: streaming = self._synthesizer._streaming self._synthesizer._streaming = True - nodes = await self.aretrieve(query_bundle) + nodes, db_schema = await self.aretrieve(query_bundle) stream_response = await self._synthesizer.asynthesize( - query=query_bundle, nodes=nodes + query=query_bundle, db_schema=db_schema, nodes=nodes ) self._synthesizer._streaming = streaming diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py index cc205166..ebfec1d4 100644 --- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py +++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py @@ -36,7 +36,6 @@ from llama_index.core.service_context import ServiceContext from llama_index.core.settings import ( Settings, - callback_manager_from_settings_or_context, embed_model_from_settings_or_context, llm_from_settings_or_context, ) @@ -310,7 +309,7 @@ def inspect_db_connection( return sql_database, tables, table_descriptions -class MyNLSQLRetriever(BaseRetriever, PromptMixin): +class MyNLSQLRetriever(PromptMixin): """Text-to-SQL Retriever. Retrieves via text. @@ -374,10 +373,10 @@ def __init__( self._handle_sql_errors = handle_sql_errors self._sql_only = sql_only self._verbose = verbose - super().__init__( - callback_manager=callback_manager - or callback_manager_from_settings_or_context(Settings, service_context) - ) + # super().__init__( + # callback_manager=callback_manager + # or callback_manager_from_settings_or_context(Settings, service_context) + # ) @classmethod def from_config( @@ -546,7 +545,11 @@ def retrieve_with_metadata( # add query_tables into metadata retrieved_nodes[0].metadata["query_tables"] = query_tables - return retrieved_nodes, {"sql_query": sql_query_str, **metadata} + return retrieved_nodes, { + "sql_query": sql_query_str, + "schema_description": table_desc_str, + **metadata, + } async def aretrieve_with_metadata( self, str_or_query_bundle: QueryType @@ -645,7 +648,11 @@ async def aretrieve_with_metadata( # add query_tables into metadata retrieved_nodes[0].metadata["query_tables"] = query_tables - return retrieved_nodes, {"sql_query": sql_query_str, **metadata} + return retrieved_nodes, { + "sql_query": sql_query_str, + "schema_description": table_desc_str, + **metadata, + } def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: table_collection = list() @@ -672,13 +679,13 @@ def _sql_query_modification(self, query_tables: list, sql_query_str: str): def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve nodes given query.""" - retrieved_nodes, _ = self.retrieve_with_metadata(query_bundle) - return retrieved_nodes + retrieved_nodes, metadata = self.retrieve_with_metadata(query_bundle) + return retrieved_nodes, metadata["schema_description"] async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Async retrieve nodes given query.""" - retrieved_nodes, _ = await self.aretrieve_with_metadata(query_bundle) - return retrieved_nodes + retrieved_nodes, metadata = await self.aretrieve_with_metadata(query_bundle) + return retrieved_nodes, metadata["schema_description"] def _get_table_context(self, query_bundle: QueryBundle) -> str: """Get table context. diff --git a/tests/integrations/test_nl2pandas_retriever.py b/tests/integrations/test_nl2pandas_retriever.py index 7d5055ef..b05db905 100644 --- a/tests/integrations/test_nl2pandas_retriever.py +++ b/tests/integrations/test_nl2pandas_retriever.py @@ -69,13 +69,13 @@ def test_data_analysis_synthesizer(): data_analysis_synthesizer = DataAnalysisSynthesizer() res_get_response = data_analysis_synthesizer.get_response( - query_str=query, retrieved_nodes=retrieved_nodes + query_str=query, db_schema="", retrieved_nodes=retrieved_nodes ) assert len(res_get_response) > 0 res_synthesize = data_analysis_synthesizer.synthesize( - query=query, nodes=retrieved_nodes + query=query, db_schema="", nodes=retrieved_nodes ) assert len(res_synthesize.response) > 0 diff --git a/tests/integrations/test_nl2sql_retriever.py b/tests/integrations/test_nl2sql_retriever.py index 8f9d3c81..7ab9d584 100644 --- a/tests/integrations/test_nl2sql_retriever.py +++ b/tests/integrations/test_nl2sql_retriever.py @@ -106,6 +106,6 @@ def test_nl2sql_retriever(db_connection): tables=db_tables, ) - res = nl2sql_retriever.retrieve("找出体重大于10的宠物的数量") + res, _ = nl2sql_retriever._retrieve("找出体重大于10的宠物的数量") assert res[0].score == 1