From 4b8c027af18aafa968403c0ae0f920c65c3a8f06 Mon Sep 17 00:00:00 2001 From: userpj Date: Thu, 14 Nov 2024 15:31:44 +0800 Subject: [PATCH] =?UTF-8?q?chatflow=E5=B7=A5=E4=BD=9C=E6=B5=81=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E5=AF=B9=E8=AF=9D=E9=AA=8C=E8=AF=81=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0'=E4=BF=A1=E6=81=AF=E6=94=B6=E9=9B=86=E8=8A=82?= =?UTF-8?q?=E7=82=B9'=E5=9B=9E=E5=A4=8D=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../appbuilder_client/appbuilder_client.py | 64 ++++++++------- .../console/appbuilder_client/data_class.py | 40 +++++++++- .../tests/test_appbuilder_client_chatflow.py | 79 +++++++++++++++++++ 3 files changed, 150 insertions(+), 33 deletions(-) create mode 100644 python/tests/test_appbuilder_client_chatflow.py diff --git a/python/core/console/appbuilder_client/appbuilder_client.py b/python/core/console/appbuilder_client/appbuilder_client.py index 3442318b9..f4d6b1188 100644 --- a/python/core/console/appbuilder_client/appbuilder_client.py +++ b/python/core/console/appbuilder_client/appbuilder_client.py @@ -26,6 +26,7 @@ from appbuilder.utils.logger_util import logger from appbuilder.utils.trace.tracer_wrapper import client_run_trace, client_tool_trace + @deprecated(reason="use describe_apps instead") @client_tool_trace def get_app_list( @@ -72,13 +73,14 @@ def get_app_list( out = resp.data return out + @client_tool_trace def describe_apps( - marker: Optional[str]=None, - maxKeys: int=10, + marker: Optional[str] = None, + maxKeys: int = 10, secret_key: Optional[str] = None, gateway: Optional[str] = None -)-> list[data_class.AppOverview]: +) -> list[data_class.AppOverview]: """ 该接口查询用户下状态为已发布的应用列表 @@ -144,11 +146,11 @@ class AppBuilderClient(Component): r""" AppBuilderClient 组件支持调用在[百度智能云千帆AppBuilder](https://cloud.baidu.com/product/AppBuilder)平台上 构建并发布的智能体应用,具体包括创建会话、上传文档、运行对话等。 - + Examples: - + .. code-block:: python - + import appbuilder # 请前往千帆AppBuilder官网创建密钥,流程详见:https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5 os.environ["APPBUILDER_TOKEN"] = '...' @@ -160,15 +162,15 @@ class AppBuilderClient(Component): message = client.run(conversation_id, "今天你好吗?") # 打印对话结果 print(message.content) - + """ def __init__(self, app_id: str, **kwargs): r"""初始化智能体应用 - + Args: app_id (str: 必须) : 应用唯一ID - + Returns: response (obj: `AppBuilderClient`): 智能体实例 """ @@ -186,13 +188,13 @@ def create_conversation(self) -> str: r"""创建会话并返回会话ID 会话ID在服务端用于上下文管理、绑定会话文档等,如需开始新的会话,请创建并使用新的会话ID - + Args: 无 - + Returns: response (str): 唯一会话ID - + """ headers = self.http_client.auth_header_v2() headers["Content-Type"] = "application/json" @@ -208,19 +210,20 @@ def create_conversation(self) -> str: @client_tool_trace def upload_local_file(self, conversation_id, local_file_path: str) -> str: r"""上传文件并将文件与会话ID进行绑定,后续可使用该文件ID进行对话,目前仅支持上传xlsx、jsonl、pdf、png等文件格式 - + 该接口用于在对话中上传文件供大模型处理,文件的有效期为7天并且不超过对话的有效期。一次只能上传一个文件。 Args: conversation_id (str) : 会话ID local_file_path (str) : 本地文件路径 - + Returns: response (str): 唯一文件ID - + """ if len(conversation_id) == 0: - raise ValueError("conversation_id is empty, you can run self.create_conversation to get a conversation_id") + raise ValueError( + "conversation_id is empty, you can run self.create_conversation to get a conversation_id") filepath = os.path.abspath(local_file_path) if not os.path.exists(filepath): @@ -250,10 +253,11 @@ def run(self, conversation_id: str, tool_outputs: list[data_class.ToolOutput] = None, tool_choice: data_class.ToolChoice = None, end_user_id: str = None, + action: data_class.Action = None, **kwargs ) -> Message: r"""运行智能体应用 - + Args: query (str): query内容 conversation_id (str): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话 @@ -264,7 +268,7 @@ def run(self, conversation_id: str, tool_choice(data_class.ToolChoice): 控制大模型使用组件的方式,默认为None end_user_id (str): 用户ID,用于区分不同用户 kwargs: 其他参数 - + Returns: message (Message): 对话结果,一个Message对象,使用message.content获取内容。 """ @@ -275,7 +279,8 @@ def run(self, conversation_id: str, ) if query == "" and (tool_outputs is None or len(tool_outputs) == 0): - raise ValueError("AppBuilderClient Run API: query and tool_outputs cannot both be empty") + raise ValueError( + "AppBuilderClient Run API: query and tool_outputs cannot both be empty") req = data_class.AppBuilderClientRequest( app_id=self.app_id, @@ -287,6 +292,7 @@ def run(self, conversation_id: str, tool_outputs=tool_outputs, tool_choice=tool_choice, end_user_id=end_user_id, + action=action, ) headers = self.http_client.auth_header_v2() @@ -308,13 +314,13 @@ def run(self, conversation_id: str, return Message(content=out) def run_with_handler(self, - conversation_id: str, - query: str = "", - file_ids: list = [], - tools: list[data_class.Tool] = None, - stream: bool = False, - event_handler = None, - **kwargs): + conversation_id: str, + query: str = "", + file_ids: list = [], + tools: list[data_class.Tool] = None, + stream: bool = False, + event_handler=None, + **kwargs): r"""运行智能体应用,并通过事件处理器处理事件 Args: @@ -374,11 +380,11 @@ class AgentBuilder(AppBuilderClient): r"""AgentBuilder是继承自AppBuilderClient的一个子类,用于构建和管理智能体应用。 支持调用在[百度智能云千帆AppBuilder](https://cloud.baidu.com/product/AppBuilder)平台上 构建并发布的智能体应用,具体包括创建会话、上传文档、运行对话等。 - + Examples: - + .. code-block:: python - + import appbuilder # 请前往千帆AppBuilder官网创建密钥,流程详见:https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5 os.environ["APPBUILDER_TOKEN"] = '...' diff --git a/python/core/console/appbuilder_client/data_class.py b/python/core/console/appbuilder_client/data_class.py index b19174f25..2b3e8a455 100644 --- a/python/core/console/appbuilder_client/data_class.py +++ b/python/core/console/appbuilder_client/data_class.py @@ -23,21 +23,26 @@ class Function(BaseModel): description: str = Field(..., description="工具描述") parameters: dict = Field(..., description="工具参数, json_schema格式") + class Tool(BaseModel): type: str = "function" function: Function = Field(..., description="工具信息") + class ToolOutput(BaseModel): tool_call_id: str = Field(..., description="工具调用ID") output: str = Field(..., description="工具输出") + class FunctionCallDetail(BaseModel): name: str = Field(..., description="函数的名称") arguments: dict = Field(..., description="模型希望您传递给函数的参数") + class ToolCall(BaseModel): id: str = Field(..., description="工具调用ID") - type: str = Field("function", description="需要输出的工具调用的类型。就目前而言,这始终是function") + type: str = Field( + "function", description="需要输出的工具调用的类型。就目前而言,这始终是function") function: FunctionCallDetail = Field(..., description="函数定义") @@ -62,6 +67,25 @@ class ToolChoice(BaseModel): ) +class ActionInterruptEvent(BaseModel): + id: str = Field(..., description="要回复的'信息收集节点'中断事件ID") + type: str = Field(..., description="要回复的'信息收集节点'中断事件类型,当前仅chat") + + +class ActionParameters(BaseModel): + interrupt_event: ActionInterruptEvent = Field( + ..., description="要回复的'信息收集节点'中断事件") + + +class Action(BaseModel): + action_type: str = Field(..., + description="action类型,目前可用值'resume', 用于回复信息收集节点的消息") + parameters: ActionParameters = Field( + ..., + description="对话时要进行的特殊操作。如回复工作流agent中'信息收集节点'的消息。", + ) + + class AppBuilderClientRequest(BaseModel): """会话请求参数 属性: @@ -80,6 +104,7 @@ class AppBuilderClientRequest(BaseModel): tool_outputs: Optional[list[ToolOutput]] = None tool_choice: Optional[ToolChoice] = None end_user_id: Optional[str] = None + action: Optional[Action] = None class Usage(BaseModel): @@ -296,11 +321,13 @@ class CreateConversationResponse(BaseModel): class AppBuilderClientAppListRequest(BaseModel): - limit: int = Field(default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) + limit: int = Field( + default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) after: str = Field( default="", description="用于分页的游标。after 是一个应用的id,它定义了在列表中的位置。例如,如果你发出一个列表请求并收到 10个对象,以 app_id_123 结束,那么你后续的调用可以包含 after=app_id_123 以获取列表的下一页数据。") before: str = Field(default="", description="用于分页的游标。与after相反,填写它将获取前一页数据") + class AppOverview(BaseModel): id: str = Field("", description="应用ID") name: str = Field("", description="应用名称") @@ -312,14 +339,19 @@ class AppOverview(BaseModel): isPublished: Optional[bool] = Field(None, description="是否已发布") updateTime: Optional[int] = Field(None, description="更新时间。时间戳,单位秒") + class AppBuilderClientAppListResponse(BaseModel): request_id: str = Field("", description="请求ID") data: Optional[list[AppOverview]] = Field( [], description="应用概览列表") + class DescribeAppsRequest(BaseModel): - maxKeys: int = Field(default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) - marker: str = Field(default=None, description="用于分页的游标。marker 是应用的id,它定义了在列表中的位置。例如,如果你发出一个列表请求并收到 10个对象,以 app_id_123 开始,那么可以使用 marker=app_id_123 来获取列表的下一页数据") + maxKeys: int = Field( + default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) + marker: str = Field( + default=None, description="用于分页的游标。marker 是应用的id,它定义了在列表中的位置。例如,如果你发出一个列表请求并收到 10个对象,以 app_id_123 开始,那么可以使用 marker=app_id_123 来获取列表的下一页数据") + class DescribeAppsResponse(BaseModel): requestId: str = Field("", description="请求ID") diff --git a/python/tests/test_appbuilder_client_chatflow.py b/python/tests/test_appbuilder_client_chatflow.py new file mode 100644 index 000000000..a0e0446d8 --- /dev/null +++ b/python/tests/test_appbuilder_client_chatflow.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import appbuilder + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientChatflow(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "4403205e-fb83-4fac-96d8-943bdb63796f" + + def test_agent_builder_run(self): + # 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败 + """ + 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败 + + Args: + self (unittest.TestCase): unittest的TestCase对象 + + Raises: + None: 如果app_id不为空,则不会引发任何异常 + unittest.SkipTest (optional): 如果app_id为空,则跳过单测执行 + """ + if len(self.app_id) == 0: + self.skipTest("self.app_id is empty") + appbuilder.logger.setLevel("ERROR") + builder = appbuilder.AppBuilderClient(self.app_id) + conversation_id = builder.create_conversation() + msg = builder.run(conversation_id, "查天气", stream=True) + + interrupt_event_id = None + for ans in msg.content: + for event in ans.events: + if event.content_type == "chatflow_interrupt": + assert event.event_type == "chatflow" + interrupt_event_id = event.detail.get("interrupt_event_id") + + self.assertIsNotNone(interrupt_event_id) + msg2 = builder.run(conversation_id=conversation_id, + query="我先查个航班动态", stream=True, + action={"action_type": "resume", + "parameters": { + "interrupt_event": { + "id": interrupt_event_id, + "type": "chat" + } + }}) + publish_message = None + for ans2 in msg2.content: + for event in ans2.events: + if event.content_type == "publish_message": + publish_message = event.detail.get("message") + print(publish_message) + self.assertIsNotNone(publish_message) + + +if __name__ == "__main__": + unittest.main()