From 5d33bdf067eef27d107dac9c41406566ede005d1 Mon Sep 17 00:00:00 2001 From: userpj Date: Wed, 20 Nov 2024 19:57:42 +0800 Subject: [PATCH] =?UTF-8?q?chatflow=E5=B7=A5=E4=BD=9C=E6=B5=81=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E5=AF=B9=E8=AF=9D=E9=AA=8C=E8=AF=81=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0'=E4=BF=A1=E6=81=AF=E6=94=B6=E9=9B=86=E8=8A=82?= =?UTF-8?q?=E7=82=B9'=E5=9B=9E=E5=A4=8D=E5=8A=9F=E8=83=BD=20(#601)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chatflow工作流应用对话验证,增加'信息收集节点'回复功能 --- go/appbuilder/app_builder_client.go | 70 +++++-- go/appbuilder/app_builder_client_data.go | 84 ++++++-- go/appbuilder/app_builder_client_test.go | 102 +++++++++- .../AppBuilderClientRunRequest.java | 76 +++++++- .../model/appbuilderclient/Event.java | 2 + .../model/appbuilderclient/EventContent.java | 11 ++ .../appbuilder/AppBuilderClientTest.java | 42 ++++ .../assistant/threads/messages/messages.py | 1 - .../appbuilder_client/appbuilder_client.py | 131 ++++++++++--- .../console/appbuilder_client/data_class.py | 52 ++++- .../appbuilder_client/event_handler.py | 182 ++++++++++++------ .../tests/test_appbuilder_client_chatflow.py | 108 +++++++++++ ...ppbuilder_client_chatflow_event_handler.py | 128 ++++++++++++ ...uilder_client_chatflow_event_handler_v2.py | 86 +++++++++ 14 files changed, 947 insertions(+), 128 deletions(-) create mode 100644 python/tests/test_appbuilder_client_chatflow.py create mode 100644 python/tests/test_appbuilder_client_chatflow_event_handler.py create mode 100644 python/tests/test_appbuilder_client_chatflow_event_handler_v2.py diff --git a/go/appbuilder/app_builder_client.go b/go/appbuilder/app_builder_client.go index 1040c4227..c2c6921c8 100644 --- a/go/appbuilder/app_builder_client.go +++ b/go/appbuilder/app_builder_client.go @@ -26,6 +26,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strconv" "time" ) @@ -250,17 +251,26 @@ func (t *AppBuilderClient) UploadLocalFile(conversationID string, filePath strin return val.(string), nil } -func (t *AppBuilderClient) Run(conversationID string, query string, fileIDS []string, stream bool) (AppBuilderClientIterator, error) { - if len(conversationID) == 0 { - return nil, errors.New("conversationID mustn't be empty") +func (t *AppBuilderClient) Run(param ...interface{}) (AppBuilderClientIterator, error) { + if len(param) == 0 { + return nil, errors.New("no arguments provided") + } + var err error + var req AppBuilderClientRunRequest + + if reflect.TypeOf(param[0]) == reflect.TypeOf(AppBuilderClientRunRequest{}) { + req = param[0].(AppBuilderClientRunRequest) + } else { + req, err = t.buildAppBuilderClientRunRequest(param...) + if err != nil { + return nil, err + } } - m := map[string]any{ - "app_id": t.appID, - "conversation_id": conversationID, - "query": query, - "file_ids": fileIDS, - "stream": stream, + + if len(req.ConversationID) == 0 { + return nil, errors.New("conversationID mustn't be empty") } + request := http.Request{} serviceURL, err := t.sdkConfig.ServiceURLV2("/app/conversation/runs") @@ -268,33 +278,63 @@ func (t *AppBuilderClient) Run(conversationID string, query string, fileIDS []st return nil, err } + header := t.sdkConfig.AuthHeaderV2() request.URL = serviceURL request.Method = "POST" - header := t.sdkConfig.AuthHeaderV2() header.Set("Content-Type", "application/json") request.Header = header - data, _ := json.Marshal(m) + data, _ := json.Marshal(req) request.Body = NopCloser(bytes.NewReader(data)) - request.ContentLength = int64(len(data)) // 手动设置长度 + request.ContentLength = int64(len(data)) t.sdkConfig.BuildCurlCommand(&request) - resp, err := t.client.Do(&request) if err != nil { return nil, err } - requestID, err := checkHTTPResponse(resp) if err != nil { return nil, fmt.Errorf("requestID=%s, err=%v", requestID, err) } r := NewSSEReader(1024*1024, bufio.NewReader(resp.Body)) - if stream { + if req.Stream { return &AppBuilderClientStreamIterator{requestID: requestID, r: r, body: resp.Body}, nil } return &AppBuilderClientOnceIterator{body: resp.Body}, nil } +func (t *AppBuilderClient) buildAppBuilderClientRunRequest(param ...interface{}) (AppBuilderClientRunRequest, error) { + conversationID, ok := param[0].(string) + if !ok { + return AppBuilderClientRunRequest{}, errors.New("conversationID must be string type") + } + query, ok := param[1].(string) + if !ok { + return AppBuilderClientRunRequest{}, errors.New("query must be string type") + } + + var fileIDS []string + if param[2] != nil { + fileIDS, ok = param[2].([]string) + if !ok { + fileIDS = nil + } + } + + stream, ok := param[3].(bool) + if !ok { + stream = false + } + + return AppBuilderClientRunRequest{ + AppID: t.appID, + ConversationID: conversationID, + Query: query, + Stream: stream, + FileIDs: fileIDS, + }, nil +} + func (t *AppBuilderClient) RunWithToolCall(req AppBuilderClientRunRequest) (AppBuilderClientIterator, error) { if len(req.ConversationID) == 0 { return nil, errors.New("conversationID mustn't be empty") diff --git a/go/appbuilder/app_builder_client_data.go b/go/appbuilder/app_builder_client_data.go index 288c24805..7ca8c5ac0 100644 --- a/go/appbuilder/app_builder_client_data.go +++ b/go/appbuilder/app_builder_client_data.go @@ -23,25 +23,33 @@ import ( ) const ( - CodeContentType = "code" - TextContentType = "text" - ImageContentType = "image" - RAGContentType = "rag" - FunctionCallContentType = "function_call" - AudioContentType = "audio" - VideoContentType = "video" - StatusContentType = "status" + CodeContentType = "code" + TextContentType = "text" + ImageContentType = "image" + RAGContentType = "rag" + FunctionCallContentType = "function_call" + AudioContentType = "audio" + VideoContentType = "video" + StatusContentType = "status" + ChatflowInterruptContentType = "chatflow_interrupt" + PublishMessageContentType = "publish_message" +) + +const ( + ChatflowEventType = "chatflow" ) var TypeToStruct = map[string]reflect.Type{ - CodeContentType: reflect.TypeOf(CodeDetail{}), - TextContentType: reflect.TypeOf(TextDetail{}), - ImageContentType: reflect.TypeOf(ImageDetail{}), - RAGContentType: reflect.TypeOf(RAGDetail{}), - FunctionCallContentType: reflect.TypeOf(FunctionCallDetail{}), - AudioContentType: reflect.TypeOf(AudioDetail{}), - VideoContentType: reflect.TypeOf(VideoDetail{}), - StatusContentType: reflect.TypeOf(StatusDetail{}), + CodeContentType: reflect.TypeOf(CodeDetail{}), + TextContentType: reflect.TypeOf(TextDetail{}), + ImageContentType: reflect.TypeOf(ImageDetail{}), + RAGContentType: reflect.TypeOf(RAGDetail{}), + FunctionCallContentType: reflect.TypeOf(FunctionCallDetail{}), + AudioContentType: reflect.TypeOf(AudioDetail{}), + VideoContentType: reflect.TypeOf(VideoDetail{}), + StatusContentType: reflect.TypeOf(StatusDetail{}), + ChatflowInterruptContentType: reflect.TypeOf(ChatflowInterruptDetail{}), + PublishMessageContentType: reflect.TypeOf(PublishMessageDetail{}), } type AppBuilderClientRunRequest struct { @@ -50,9 +58,11 @@ type AppBuilderClientRunRequest struct { Stream bool `json:"stream"` EndUserID *string `json:"end_user_id"` ConversationID string `json:"conversation_id"` + FileIDs []string `json:"file_ids"` Tools []Tool `json:"tools"` ToolOutputs []ToolOutput `json:"tool_outputs"` ToolChoice *ToolChoice `json:"tool_choice"` + Action *Action `json:"action"` } type Tool struct { @@ -81,6 +91,36 @@ type ToolChoiceFunction struct { Input map[string]interface{} `json:"input"` } +type Action struct { + ActionType string `json:"action_type"` + Paramters *ActionParamters `json:"parameters"` +} + +type ActionParamters struct { + InterruptEvent *ActionInterruptEvent `json:"interrupt_event"` +} + +type ActionInterruptEvent struct { + ID string `json:"id"` + Type string `json:"type"` +} + +func NewResumeAction(eventId string) *Action { + return NewAction("resume", eventId, "chat") +} + +func NewAction(actionType string, eventId string, eventType string) *Action { + return &Action{ + ActionType: actionType, + Paramters: &ActionParamters{ + InterruptEvent: &ActionInterruptEvent{ + ID: eventId, + Type: eventType, + }, + }, + } +} + type AgentBuilderRawResponse struct { RequestID string `json:"request_id"` Date string `json:"date"` @@ -122,7 +162,7 @@ type Event struct { EventType string ContentType string Usage Usage - Detail any // 将any替换为interface{} + Detail any ToolCalls []ToolCall } @@ -185,6 +225,16 @@ type VideoDetail struct { type StatusDetail struct{} +type ChatflowInterruptDetail struct { + InterruptEventID string `json:"interrupt_event_id"` + InterruptEventType string `json:"interrupt_event_type"` +} + +type PublishMessageDetail struct { + Message string `json:"message"` + MessageID string `json:"message_id"` +} + type DefaultDetail struct { URLS []string `json:"urls"` Files []string `json:"files"` diff --git a/go/appbuilder/app_builder_client_test.go b/go/appbuilder/app_builder_client_test.go index 000ba85aa..e6b817566 100644 --- a/go/appbuilder/app_builder_client_test.go +++ b/go/appbuilder/app_builder_client_test.go @@ -574,7 +574,7 @@ func TestAppBuilderClientRunToolChoice(t *testing.T) { input := make(map[string]any) input["city"] = "北京" end_user_id := "go_user_id_0" - i, err := client.RunWithToolCall(AppBuilderClientRunRequest{ + i, err := client.Run(AppBuilderClientRunRequest{ ConversationID: conversationID, AppID: appID, Query: "你能干什么", @@ -610,3 +610,103 @@ func TestAppBuilderClientRunToolChoice(t *testing.T) { t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") } } + +func TestAppBuilderClientRunChatflow(t *testing.T) { + var logBuffer bytes.Buffer + + os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG") + + config, err := NewSDKConfig("", "") + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new http client config failed: %v", err) + } + + appID := "4403205e-fb83-4fac-96d8-943bdb63796f" + client, err := NewAppBuilderClient(appID, config) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new AgentBuidler instance failed") + } + + conversationID, err := client.CreateConversation() + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("create conversation failed: %v", err) + } + + i, err := client.Run(AppBuilderClientRunRequest{ + ConversationID: conversationID, + AppID: appID, + Query: "查天气", + Stream: true, + }) + + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run failed:%v", err) + } + + var interruptId string + for answer, err := i.Next(); err == nil; answer, err = i.Next() { + for _, ev := range answer.Events { + if ev.ContentType == ChatflowInterruptContentType { + if ev.EventType != ChatflowEventType { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("event type error:%v", err) + } + + deatil := ev.Detail.(ChatflowInterruptDetail) + interruptId = deatil.InterruptEventID + break + } + } + } + + if len(interruptId) == 0 { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("interrupt id is empty") + } + + i2, err := client.Run(AppBuilderClientRunRequest{ + ConversationID: conversationID, + AppID: appID, + Query: "我先查个航班动态", + Stream: true, + Action: NewResumeAction(interruptId), + }) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run failed:%v", err) + } + + var message string + for answer, err := i2.Next(); err == nil; answer, err = i2.Next() { + for _, ev := range answer.Events { + if ev.ContentType == PublishMessageContentType { + if ev.EventType != ChatflowEventType { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("event type error:%v", err) + } + + detail := ev.Detail.(PublishMessageDetail) + message = detail.Message + break + } + } + } + if len(message) == 0 { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("message is empty") + } + fmt.Println(message) + + // 如果测试失败,则输出缓冲区中的日志 + if t.Failed() { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + fmt.Println(logBuffer.String()) + } else { // else 紧跟在右大括号后面 + // 测试通过,打印文件名和测试函数名 + t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java index a1148a0ee..b836f987d 100644 --- a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java +++ b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java @@ -19,6 +19,7 @@ public class AppBuilderClientRunRequest { private ToolOutput[] ToolOutputs; @SerializedName("tool_choice") private ToolChoice ToolChoice; + private Action action; public AppBuilderClientRunRequest() { } @@ -123,6 +124,14 @@ public void setToolChoice(ToolChoice toolChoice) { this.ToolChoice = toolChoice; } + public Action getAction() { + return action; + } + + public void setAction(Action action) { + this.action = action; + } + public static class Tool { private String type; private Function function; @@ -189,9 +198,9 @@ public static class ToolChoice { private String type; private Function function; - public ToolChoice(String type, Function function) { - this.type=type; - this.function=function; + public ToolChoice(String type, Function function) { + this.type = type; + this.function = function; } public String getType() { @@ -220,4 +229,65 @@ public Map getInput() { } } } + + public static class Action { + @SerializedName("action_type") + private String actionType; + private Parameters parameters; + + // 回复消息节点构造方法 + public static Action createAction(String interruptId) { + return createAction("resume", interruptId, "chat"); + } + + public static Action createAction(String actionType, String id, String type) { + Parameters.InterruptEvent interruptEvent = new Parameters.InterruptEvent(id, type); + Parameters parameters = new Parameters(interruptEvent); + return new Action(actionType, parameters); + } + + public Action(String actionType, Parameters parameters) { + this.actionType = actionType; + this.parameters = parameters; + } + + public String getActionType() { + return actionType; + } + + public Parameters getParameters() { + return parameters; + } + + public static class Parameters { + @SerializedName("interrupt_event") + private InterruptEvent interruptEvent; + + public Parameters(InterruptEvent interruptEvent) { + this.interruptEvent = interruptEvent; + } + + public InterruptEvent getInterruptEvent() { + return interruptEvent; + } + + public static class InterruptEvent { + private String id; + private String type; + + public InterruptEvent(String id, String type) { + this.id = id; + this.type = type; + } + + public String getId() { + return id; + } + + public String getType() { + return type; + } + } + } + } } diff --git a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/Event.java b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/Event.java index e2da38e61..fe7248974 100644 --- a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/Event.java +++ b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/Event.java @@ -4,6 +4,8 @@ import com.google.gson.annotations.SerializedName; public class Event { + public static final String ChatflowEventType = "chatflow"; + private String code; private String message; private String eventType; diff --git a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/EventContent.java b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/EventContent.java index 2ade51b70..4c6a672af 100644 --- a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/EventContent.java +++ b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/EventContent.java @@ -5,6 +5,17 @@ import java.util.Map; public class EventContent { + public static final String CodeContentType = "code"; + public static final String TextContentType = "text"; + public static final String ImageContentType = "image"; + public static final String RAGContentType = "rag"; + public static final String FunctionCallContentType = "function_call"; + public static final String AudioContentType = "audio"; + public static final String VideoContentType = "video"; + public static final String StatusContentType = "status"; + public static final String ChatflowInterruptContentType = "chatflow_interrupt"; + public static final String PublishMessageContentType = "publish_message"; + @SerializedName("event_code") private String eventCode; @SerializedName("event_message") diff --git a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java index 44b1764f4..cb33dc655 100644 --- a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java @@ -13,6 +13,8 @@ import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult; import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest; import com.baidubce.appbuilder.model.appbuilderclient.AppsDescribeRequest; +import com.baidubce.appbuilder.model.appbuilderclient.Event; +import com.baidubce.appbuilder.model.appbuilderclient.EventContent; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientRunRequest; import org.junit.Before; import org.junit.Test; @@ -21,12 +23,14 @@ public class AppBuilderClientTest { String appId; + String chatflowAppId; @Before public void setUp() { System.setProperty("APPBUILDER_TOKEN", System.getenv("APPBUILDER_TOKEN")); System.setProperty("APPBUILDER_LOGLEVEL", "DEBUG"); appId = "aa8af334-df27-4855-b3d1-0d249c61fc08"; + chatflowAppId = "4403205e-fb83-4fac-96d8-943bdb63796f"; } @Test @@ -120,4 +124,42 @@ public void AppBuilderClientRunToolChoiceTest() throws IOException, AppBuilderSe System.out.println(result); } } + + @Test + public void AppBuilderClientRunChatflowTest() throws IOException, AppBuilderServerException { + AppBuilderClient builder = new AppBuilderClient(chatflowAppId); + String conversationId = builder.createConversation(); + assertNotNull(conversationId); + AppBuilderClientRunRequest request = new AppBuilderClientRunRequest(chatflowAppId, conversationId, "查天气", true); + AppBuilderClientIterator itor = builder.run(request); + assertTrue(itor.hasNext()); + String interruptEventId = ""; + while (itor.hasNext()) { + AppBuilderClientResult result = itor.next(); + for (Event event : result.getEvents()) { + System.out.println(event.getContentType()); + if (event.getContentType().equals(EventContent.ChatflowInterruptContentType)) { + assertEquals(event.getEventType(), Event.ChatflowEventType); + interruptEventId = event.getDetail().get("interrupt_event_id").toString(); + } + } + } + + assert interruptEventId != null && !interruptEventId.isEmpty(); + AppBuilderClientRunRequest request2 = new AppBuilderClientRunRequest(chatflowAppId, conversationId, "我先查个航班动态", + true); + request2.setAction(AppBuilderClientRunRequest.Action.createAction(interruptEventId)); + AppBuilderClientIterator itor2 = builder.run(request2); + assertTrue(itor2.hasNext()); + String message = ""; + while (itor2.hasNext()) { + AppBuilderClientResult result2 = itor2.next(); + for (Event event : result2.getEvents()) { + if (event.getContentType().equals(EventContent.PublishMessageContentType)) { + message = event.getDetail().get("message").toString(); + } + } + } + assert message != null && !message.isEmpty(); + } } diff --git a/python/core/assistant/threads/messages/messages.py b/python/core/assistant/threads/messages/messages.py index 8038ef788..c06eaf5e6 100644 --- a/python/core/assistant/threads/messages/messages.py +++ b/python/core/assistant/threads/messages/messages.py @@ -243,4 +243,3 @@ def files(self, self._http_client.check_assistant_response(request_id, data) response = thread_type.AssistantMessageFilesResponse(**data) return response - \ No newline at end of file diff --git a/python/core/console/appbuilder_client/appbuilder_client.py b/python/core/console/appbuilder_client/appbuilder_client.py index 3442318b9..357580bef 100644 --- a/python/core/console/appbuilder_client/appbuilder_client.py +++ b/python/core/console/appbuilder_client/appbuilder_client.py @@ -16,6 +16,7 @@ import os import json import uuid +import queue from typing import Optional from appbuilder.core.component import Message, Component from appbuilder.core.console.appbuilder_client import data_class @@ -26,6 +27,7 @@ from appbuilder.utils.logger_util import logger from appbuilder.utils.trace.tracer_wrapper import client_run_trace, client_tool_trace + @deprecated(reason="use describe_apps instead") @client_tool_trace def get_app_list( @@ -72,13 +74,14 @@ def get_app_list( out = resp.data return out + @client_tool_trace def describe_apps( - marker: Optional[str]=None, - maxKeys: int=10, + marker: Optional[str] = None, + maxKeys: int = 10, secret_key: Optional[str] = None, gateway: Optional[str] = None -)-> list[data_class.AppOverview]: +) -> list[data_class.AppOverview]: """ 该接口查询用户下状态为已发布的应用列表 @@ -144,11 +147,11 @@ class AppBuilderClient(Component): r""" AppBuilderClient 组件支持调用在[百度智能云千帆AppBuilder](https://cloud.baidu.com/product/AppBuilder)平台上 构建并发布的智能体应用,具体包括创建会话、上传文档、运行对话等。 - + Examples: - + .. code-block:: python - + import appbuilder # 请前往千帆AppBuilder官网创建密钥,流程详见:https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5 os.environ["APPBUILDER_TOKEN"] = '...' @@ -160,15 +163,15 @@ class AppBuilderClient(Component): message = client.run(conversation_id, "今天你好吗?") # 打印对话结果 print(message.content) - + """ def __init__(self, app_id: str, **kwargs): r"""初始化智能体应用 - + Args: app_id (str: 必须) : 应用唯一ID - + Returns: response (obj: `AppBuilderClient`): 智能体实例 """ @@ -186,13 +189,13 @@ def create_conversation(self) -> str: r"""创建会话并返回会话ID 会话ID在服务端用于上下文管理、绑定会话文档等,如需开始新的会话,请创建并使用新的会话ID - + Args: 无 - + Returns: response (str): 唯一会话ID - + """ headers = self.http_client.auth_header_v2() headers["Content-Type"] = "application/json" @@ -208,19 +211,20 @@ def create_conversation(self) -> str: @client_tool_trace def upload_local_file(self, conversation_id, local_file_path: str) -> str: r"""上传文件并将文件与会话ID进行绑定,后续可使用该文件ID进行对话,目前仅支持上传xlsx、jsonl、pdf、png等文件格式 - + 该接口用于在对话中上传文件供大模型处理,文件的有效期为7天并且不超过对话的有效期。一次只能上传一个文件。 Args: conversation_id (str) : 会话ID local_file_path (str) : 本地文件路径 - + Returns: response (str): 唯一文件ID - + """ if len(conversation_id) == 0: - raise ValueError("conversation_id is empty, you can run self.create_conversation to get a conversation_id") + raise ValueError( + "conversation_id is empty, you can run self.create_conversation to get a conversation_id") filepath = os.path.abspath(local_file_path) if not os.path.exists(filepath): @@ -250,10 +254,11 @@ def run(self, conversation_id: str, tool_outputs: list[data_class.ToolOutput] = None, tool_choice: data_class.ToolChoice = None, end_user_id: str = None, + action: data_class.Action = None, **kwargs ) -> Message: r"""运行智能体应用 - + Args: query (str): query内容 conversation_id (str): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话 @@ -263,9 +268,10 @@ def run(self, conversation_id: str, tool_outputs(list[data_class.ToolOutput]): 工具输出列表,格式为list[ToolOutput], ToolOutputd内容为本地的工具执行结果,以自然语言/json dump str描述,默认为None tool_choice(data_class.ToolChoice): 控制大模型使用组件的方式,默认为None end_user_id (str): 用户ID,用于区分不同用户 + action(data_class.Action): 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 kwargs: 其他参数 - - Returns: + + Returns: message (Message): 对话结果,一个Message对象,使用message.content获取内容。 """ @@ -275,7 +281,8 @@ def run(self, conversation_id: str, ) if query == "" and (tool_outputs is None or len(tool_outputs) == 0): - raise ValueError("AppBuilderClient Run API: query and tool_outputs cannot both be empty") + raise ValueError( + "AppBuilderClient Run API: query and tool_outputs cannot both be empty") req = data_class.AppBuilderClientRequest( app_id=self.app_id, @@ -287,6 +294,7 @@ def run(self, conversation_id: str, tool_outputs=tool_outputs, tool_choice=tool_choice, end_user_id=end_user_id, + action=action, ) headers = self.http_client.auth_header_v2() @@ -308,13 +316,14 @@ def run(self, conversation_id: str, return Message(content=out) def run_with_handler(self, - conversation_id: str, - query: str = "", - file_ids: list = [], - tools: list[data_class.Tool] = None, - stream: bool = False, - event_handler = None, - **kwargs): + conversation_id: str, + query: str = "", + file_ids: list = [], + tools: list[data_class.Tool] = None, + stream: bool = False, + event_handler=None, + action=None, + **kwargs): r"""运行智能体应用,并通过事件处理器处理事件 Args: @@ -324,6 +333,8 @@ def run_with_handler(self, tools(list[data_class.Tools], 可选): 一个Tools组成的列表,其中每个Tools对应一个工具的配置, 默认为None stream (bool): 是否流式响应 event_handler (EventHandler): 事件处理器 + action(dataclass.Action) 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + kwargs: 其他参数 Returns: @@ -337,11 +348,71 @@ def run_with_handler(self, file_ids=file_ids, tools=tools, stream=stream, + action=action, **kwargs ) return event_handler + def run_multiple_dialog_with_handler(self, + conversation_id: str, + queries: iter = None, + file_ids: iter = None, + tools: iter = None, + stream: bool = False, + event_handler=None, + actions: iter = None, + **kwargs): + r"""运行智能体应用,并通过事件处理器处理事件 + + Args: + conversation_id (str): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话 + queries (iter): 查询字符串可迭代对象 + file_ids (iter): 文件ID列表 + tools(iter, 可选): 一个Tools组成的列表,其中每个Tools对应一个工具的配置, 默认为None + stream (bool): 是否流式响应 + event_handler (EventHandler): 事件处理器 + actions(iter) 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + + kwargs: 其他参数 + Returns: + EventHandler: 事件处理器 + """ + assert event_handler is not None, "event_handler is None" + assert queries is not None, "queries is None" + + iter_queries = iter(queries) + iter_file_ids = iter(file_ids) if file_ids else iter([]) + iter_tools = iter(tools) if tools else iter([]) + iter_actions = iter(actions) if actions else iter([]) + + for index, query in enumerate(iter_queries): + file_id = next(iter_file_ids, None) + tool = next(iter_tools, None) + action = next(iter_actions, None) + + if index == 0: + yield from self.run_with_handler( + conversation_id=conversation_id, + query=query, + file_ids=file_id, + tools=tool, + stream=stream, + event_handler=event_handler, + action=action, + **kwargs, + ) + else: + event_handler.new_dialog( + query=query, + file_ids=file_id, + tools=tool, + stream=stream, + action=action, + ) + yield event_handler + event_handler.reset_state() + @staticmethod def _iterate_events(request_id, events) -> data_class.AppBuilderClientAnswer: for event in events: @@ -374,11 +445,11 @@ class AgentBuilder(AppBuilderClient): r"""AgentBuilder是继承自AppBuilderClient的一个子类,用于构建和管理智能体应用。 支持调用在[百度智能云千帆AppBuilder](https://cloud.baidu.com/product/AppBuilder)平台上 构建并发布的智能体应用,具体包括创建会话、上传文档、运行对话等。 - + Examples: - + .. code-block:: python - + import appbuilder # 请前往千帆AppBuilder官网创建密钥,流程详见:https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5 os.environ["APPBUILDER_TOKEN"] = '...' diff --git a/python/core/console/appbuilder_client/data_class.py b/python/core/console/appbuilder_client/data_class.py index b19174f25..cbc06ff31 100644 --- a/python/core/console/appbuilder_client/data_class.py +++ b/python/core/console/appbuilder_client/data_class.py @@ -23,21 +23,26 @@ class Function(BaseModel): description: str = Field(..., description="工具描述") parameters: dict = Field(..., description="工具参数, json_schema格式") + class Tool(BaseModel): type: str = "function" function: Function = Field(..., description="工具信息") + class ToolOutput(BaseModel): tool_call_id: str = Field(..., description="工具调用ID") output: str = Field(..., description="工具输出") + class FunctionCallDetail(BaseModel): name: str = Field(..., description="函数的名称") arguments: dict = Field(..., description="模型希望您传递给函数的参数") + class ToolCall(BaseModel): id: str = Field(..., description="工具调用ID") - type: str = Field("function", description="需要输出的工具调用的类型。就目前而言,这始终是function") + type: str = Field( + "function", description="需要输出的工具调用的类型。就目前而言,这始终是function") function: FunctionCallDetail = Field(..., description="函数定义") @@ -62,6 +67,37 @@ class ToolChoice(BaseModel): ) +class ActionInterruptEvent(BaseModel): + id: str = Field(..., description="要回复的'信息收集节点'中断事件ID") + type: str = Field(..., description="要回复的'信息收集节点'中断事件类型,当前仅chat") + + +class ActionParameters(BaseModel): + interrupt_event: ActionInterruptEvent = Field( + ..., description="要回复的'信息收集节点'中断事件") + + +class Action(BaseModel): + action_type: str = Field(..., + description="action类型,目前可用值'resume', 用于回复信息收集节点的消息") + parameters: ActionParameters = Field( + ..., + description="对话时要进行的特殊操作。如回复工作流agent中'信息收集节点'的消息。", + ) + + @classmethod + def create_resume_action(cls, event_id): + return { + "action_type": "resume", + "parameters": { + "interrupt_event": { + "id": event_id, + "type": "chat" + } + } + } + + class AppBuilderClientRequest(BaseModel): """会话请求参数 属性: @@ -80,6 +116,7 @@ class AppBuilderClientRequest(BaseModel): tool_outputs: Optional[list[ToolOutput]] = None tool_choice: Optional[ToolChoice] = None end_user_id: Optional[str] = None + action: Optional[Action] = None class Usage(BaseModel): @@ -296,11 +333,13 @@ class CreateConversationResponse(BaseModel): class AppBuilderClientAppListRequest(BaseModel): - limit: int = Field(default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) + limit: int = Field( + default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) after: str = Field( default="", description="用于分页的游标。after 是一个应用的id,它定义了在列表中的位置。例如,如果你发出一个列表请求并收到 10个对象,以 app_id_123 结束,那么你后续的调用可以包含 after=app_id_123 以获取列表的下一页数据。") before: str = Field(default="", description="用于分页的游标。与after相反,填写它将获取前一页数据") + class AppOverview(BaseModel): id: str = Field("", description="应用ID") name: str = Field("", description="应用名称") @@ -312,14 +351,19 @@ class AppOverview(BaseModel): isPublished: Optional[bool] = Field(None, description="是否已发布") updateTime: Optional[int] = Field(None, description="更新时间。时间戳,单位秒") + class AppBuilderClientAppListResponse(BaseModel): request_id: str = Field("", description="请求ID") data: Optional[list[AppOverview]] = Field( [], description="应用概览列表") + class DescribeAppsRequest(BaseModel): - maxKeys: int = Field(default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) - marker: str = Field(default=None, description="用于分页的游标。marker 是应用的id,它定义了在列表中的位置。例如,如果你发出一个列表请求并收到 10个对象,以 app_id_123 开始,那么可以使用 marker=app_id_123 来获取列表的下一页数据") + maxKeys: int = Field( + default=10, description="当次查询的数据大小,默认10,最大值100", le=100, ge=1) + marker: str = Field( + default=None, description="用于分页的游标。marker 是应用的id,它定义了在列表中的位置。例如,如果你发出一个列表请求并收到 10个对象,以 app_id_123 开始,那么可以使用 marker=app_id_123 来获取列表的下一页数据") + class DescribeAppsResponse(BaseModel): requestId: str = Field("", description="请求ID") diff --git a/python/core/console/appbuilder_client/event_handler.py b/python/core/console/appbuilder_client/event_handler.py index e2bf40a39..843914043 100644 --- a/python/core/console/appbuilder_client/event_handler.py +++ b/python/core/console/appbuilder_client/event_handler.py @@ -21,13 +21,13 @@ class AppBuilderClientRunContext(object): def __init__(self) -> None: """ 初始化方法。 - + Args: 无参数。 - + Returns: None - + """ self.current_event = None self.current_tool_calls = None @@ -41,18 +41,21 @@ class AppBuilderEventHandler(object): def __init__(self): pass - def init(self, - appbuilder_client, - conversation_id, - query, - file_ids=None, - tools=None, - stream: bool = False, - event_handler=None, - **kwargs): + def init( + self, + appbuilder_client, + conversation_id, + query, + file_ids=None, + tools=None, + stream: bool = False, + event_handler=None, + action=None, + **kwargs + ): """ 初始化类实例并设置相关参数。 - + Args: appbuilder_client (object): AppBuilder客户端实例对象。 conversation_id (str): 对话ID。 @@ -61,11 +64,12 @@ def init(self, tools (list, optional): 工具列表,默认为None。 stream (bool, optional): 是否使用流式处理,默认为False。 event_handler (callable, optional): 事件处理函数,默认为None。 + action (object, optional): 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 **kwargs: 其他可选参数。 - + Returns: None - + """ self._appbuilder_client = appbuilder_client self._conversation_id = conversation_id @@ -78,19 +82,21 @@ def init(self, self._is_complete = False self._need_tool_call = False self._last_tool_output = None + self._action = action - self._iterator = self.__run_process__() if not self._stream else self.__stream_run_process__() + self._iterator = self.__run_process__( + ) if not self._stream else self.__stream_run_process__() def __run_process__(self): """ 运行进程,并在每次执行后生成结果。 - + Args: 无参数。 - + Returns: Generator: 生成器,每次执行后返回结果。 - + """ while not self._is_complete: if not self._need_tool_call: @@ -100,19 +106,19 @@ def __run_process__(self): res = self._submit_tool_output() self.__event_process__(res) yield res - - self.reset_state() + if self._need_tool_call and self._is_complete: + self.reset_state() def __event_process__(self, run_response): """ 处理事件响应。 - + Args: run_response (RunResponse): 运行时响应对象。 - + Returns: None - + Raises: ValueError: 当解析事件时发生异常或工具输出为空时。 """ @@ -120,9 +126,9 @@ def __event_process__(self, run_response): event = run_response.content.events[-1] except Exception as e: raise ValueError(e) - + event_status = event.status - + if event.status == 'success': self._is_complete = True elif event.status == 'interrupt': @@ -136,9 +142,11 @@ def __event_process__(self, run_response): "interrupt": self.interrupt, "success": self.success, } - + run_context = AppBuilderClientRunContext() self._update_run_context(run_context, run_response.content) + self.handle_event_type(run_context, run_response.content) + self.handle_content_type(run_context, run_response.content) if event_status in context_func_map: func = context_func_map[event_status] func_res = func(run_context, run_response.content) @@ -150,11 +158,13 @@ def __event_process__(self, run_response): else: if not isinstance(func_res[0], data_class.ToolOutput): try: - check_tool_output = data_class.ToolOutput(**func_res[0]) + check_tool_output = data_class.ToolOutput( + **func_res[0]) except Exception as e: - logger.error("func interrupt's output should be list[ToolOutput] or list[dict(can be trans to ToolOutput)]") + logger.error( + "func interrupt's output should be list[ToolOutput] or list[dict(can be trans to ToolOutput)]") raise ValueError(e) - self._last_tool_output =func_res + self._last_tool_output = func_res else: logger.warning( "Unknown status: {}, response data: {}".format(event_status, run_response)) @@ -162,13 +172,13 @@ def __event_process__(self, run_response): def __stream_run_process__(self): """ 流式运行处理函数 - + Args: 无参数。 - + Returns: Generator[Any, None, None]: 返回处理结果的生成器。 - + """ while not self._is_complete: if not self._need_tool_call: @@ -176,18 +186,18 @@ def __stream_run_process__(self): else: res = self._submit_tool_output() for msg in self.__stream_event_process__(res): - yield msg + yield msg def __stream_event_process__(self, run_response): """ 处理流事件,并调用对应的方法 - + Args: run_response: 包含流事件信息的响应对象 - + Returns: None - + Raises: ValueError: 当处理事件时发生异常或中断时工具输出为空时 """ @@ -198,9 +208,9 @@ def __stream_event_process__(self, run_response): event = msg.events[-1] except Exception as e: raise ValueError(e) - + event_status = event.status - + if event.status == 'success': self._is_complete = True elif event.status == 'interrupt': @@ -214,9 +224,11 @@ def __stream_event_process__(self, run_response): "interrupt": self.interrupt, "success": self.success, } - + run_context = AppBuilderClientRunContext() self._update_run_context(run_context, msg) + self.handle_event_type(run_context, msg) + self.handle_content_type(run_context, msg) if event_status in context_func_map: func = context_func_map[event_status] func_res = func(run_context, msg) @@ -228,32 +240,34 @@ def __stream_event_process__(self, run_response): else: if not isinstance(func_res[0], data_class.ToolOutput): try: - check_tool_output = data_class.ToolOutput(**func_res[0]) + check_tool_output = data_class.ToolOutput( + **func_res[0]) except Exception as e: - logger.info("func interrupt's output should be list[ToolOutput] or list[dict(can be trans to ToolOutput)]") + logger.info( + "func interrupt's output should be list[ToolOutput] or list[dict(can be trans to ToolOutput)]") raise ValueError(e) - self._last_tool_output =func_res + self._last_tool_output = func_res else: logger.warning( "Unknown status: {}, response data: {}".format(event_status, run_response)) - + yield msg def _update_run_context(self, run_context, run_response): """ 更新运行上下文。 - + Args: run_context (dict): 运行上下文字典。 run_response (object): 运行响应对象。 - + Returns: None - + """ run_context.current_event = run_response.events[-1] run_context.current_tool_calls = run_context.current_event.tool_calls - run_context.current_status = run_context.current_event.status + run_context.current_status = run_context.current_event.status run_context.need_tool_submit = run_context.current_status == 'interrupt' run_context.is_complete = run_context.current_status == 'success' try: @@ -270,7 +284,8 @@ def _run(self): query=self._query, file_ids=self._file_ids, stream=self._stream, - tools=self._tools + tools=self._tools, + action=self._action, ) return res @@ -297,19 +312,19 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb) -> None: if exc_type is not None: raise exc_val - + return def reset_state(self): """ 重置该对象的状态,将所有实例变量设置为默认值。 - + Args: 无 - + Returns: 无 - + """ self._appbuilder_client = None self._conversation_id = None @@ -324,20 +339,73 @@ def reset_state(self): self._need_tool_call = False self._iterator = None + def new_dialog( + self, + query=None, + file_ids=None, + tools=None, + action=None, + stream: bool = None, + event_handler=None, + **kwargs + ): + """ + 重置handler部分参数,用于复用该handler进行多轮对话。 + + Args: + query (str): 用户输入的查询语句。 + file_ids (list, optional): 文件ID列表,默认为None。 + tools (list, optional): 工具列表,默认为None。 + stream (bool, optional): 是否使用流式处理,默认为False。 + action (object, optional): 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + event_handler (callable, optional): 事件处理函数,默认为None。 + **kwargs: 其他可选参数。 + + Returns: + None + + """ + self._query = query or self._query + self._stream = stream or self._stream + + self._file_ids = file_ids + self._tools = tools + self._event_handler = event_handler + self._kwargs = kwargs + self._action = action + + # 重置部分状态 + self._is_complete = False + self._need_tool_call = False + self._last_tool_output = None + self._iterator = ( + self.__run_process__() + if not self._stream + else self.__stream_run_process__() + ) + def until_done(self): """ 迭代并遍历内部迭代器中的所有元素,直到迭代器耗尽。 - + Args: 无参数。 - + Returns: 无返回值。 - + """ for _ in self._iterator: pass + def handle_content_type(self, run_context, run_response): + # 用户可重载该方法,用于处理不同类型的content_type + pass + + def handle_event_type(self, run_context, run_response): + # 用户可重载该方法,用于处理不同类型的event_type + pass + def interrupt(self, run_context, run_response): # 用户可重载该方法,当event_status为interrupt时,会调用该方法 pass @@ -349,7 +417,7 @@ def preparing(self, run_context, run_response): def running(self, run_context, run_response): # 用户可重载该方法,当event_status为running时,会调用该方法 pass - + def error(self, run_context, run_response): # 用户可重载该方法,当event_status为error时,会调用该方法 pass diff --git a/python/tests/test_appbuilder_client_chatflow.py b/python/tests/test_appbuilder_client_chatflow.py new file mode 100644 index 000000000..c62a4d60f --- /dev/null +++ b/python/tests/test_appbuilder_client_chatflow.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import appbuilder +from appbuilder.core.console.appbuilder_client import data_class + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientChatflow(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "4403205e-fb83-4fac-96d8-943bdb63796f" + + def test_appbuilder_run_chatflow(self): + # 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败 + """ + 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败 + + Args: + self (unittest.TestCase): unittest的TestCase对象 + + Raises: + None: 如果app_id不为空,则不会引发任何异常 + unittest.SkipTest (optional): 如果app_id为空,则跳过单测执行 + """ + if len(self.app_id) == 0: + self.skipTest("self.app_id is empty") + appbuilder.logger.setLevel("ERROR") + interrupt_ids = [] + builder = appbuilder.AppBuilderClient(self.app_id) + conversation_id = builder.create_conversation() + msg = builder.run(conversation_id, "查天气", stream=True) + + interrupt_event_id = None + for ans in msg.content: + for event in ans.events: + if event.content_type == "chatflow_interrupt": + assert event.event_type == "chatflow" + interrupt_event_id = event.detail.get("interrupt_event_id") + break + self.assertIsNotNone(interrupt_event_id) + interrupt_ids.append(interrupt_event_id) + + msg = builder.run( + conversation_id, + "查航班", + stream=True, + action=data_class.Action.create_resume_action(interrupt_event_id), + ) + interrupt_event_id = None + for ans in msg.content: + for event in ans.events: + if event.content_type == "chatflow_interrupt": + assert event.event_type == "chatflow" + interrupt_event_id = event.detail.get("interrupt_event_id") + break + self.assertIsNotNone(interrupt_event_id) + interrupt_ids.append(interrupt_event_id) + + msg2 = builder.run(conversation_id=conversation_id, + query="CA1234", stream=True, + action=data_class.Action.create_resume_action(interrupt_ids.pop())) + interrupt_event_id = None + for ans in msg2.content: + for event in ans.events: + if event.content_type == "chatflow_interrupt": + assert event.event_type == "chatflow" + interrupt_event_id = event.detail.get("interrupt_event_id") + break + self.assertIsNotNone(interrupt_event_id) + interrupt_ids.append(interrupt_event_id) + + msg2 = builder.run(conversation_id=conversation_id, + query="北京的", stream=True, + action=data_class.Action.create_resume_action(interrupt_ids.pop())) + has_multiple_dialog_event = False + for ans in msg2.content: + for event in ans.events: + if event.content_type == "multiple_dialog_event": + assert event.event_type == "chatflow" + has_multiple_dialog_event = True + break + self.assertTrue(has_multiple_dialog_event) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_appbuilder_client_chatflow_event_handler.py b/python/tests/test_appbuilder_client_chatflow_event_handler.py new file mode 100644 index 000000000..db82da1ee --- /dev/null +++ b/python/tests/test_appbuilder_client_chatflow_event_handler.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import appbuilder +from appbuilder.core.console.appbuilder_client.event_handler import ( + AppBuilderEventHandler, +) + + +class MyEventHandler(AppBuilderEventHandler): + def __init__(self): + super().__init__() + self.interrupt_ids = [] + + def handle_content_type(self, run_context, run_response): + interrupt_event_id = None + event = run_response.events[-1] + if event.content_type == "chatflow_interrupt": + interrupt_event_id = event.detail.get("interrupt_event_id") + if interrupt_event_id is not None: + self.interrupt_ids.append(interrupt_event_id) + + def _create_action(self): + if len(self.interrupt_ids) == 0: + return None + event_id = self.interrupt_ids.pop() + return { + "action_type": "resume", + "parameters": {"interrupt_event": {"id": event_id, "type": "chat"}}, + } + + def run(self, query=None): + super().new_dialog( + query=query, + action=self._create_action(), + ) + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientChatflow(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "4403205e-fb83-4fac-96d8-943bdb63796f" + + def test_appbuilder_client_run_with_handler_stream(self): + if len(self.app_id) == 0: + self.skipTest("self.app_id is empty") + appbuilder.logger.setLevel("ERROR") + builder = appbuilder.AppBuilderClient(self.app_id) + conversation_id = builder.create_conversation() + + event_handler = MyEventHandler() + event_handler.init(appbuilder_client=builder, + conversation_id=conversation_id, stream=True, query="查天气") + for data in event_handler: + pass + event_handler.run( + query="查航班", + ) + for data in event_handler: + pass + event_handler.run( + query="CA1234", + ) + for data in event_handler: + pass + event_handler.run( + query="北京的", + ) + for data in event_handler: + pass + + def test_appbuilder_client_run_with_handler(self): + if len(self.app_id) == 0: + self.skipTest("self.app_id is empty") + appbuilder.logger.setLevel("ERROR") + builder = appbuilder.AppBuilderClient(self.app_id) + conversation_id = builder.create_conversation() + + event_handler = MyEventHandler() + event_handler.init( + appbuilder_client=builder, + conversation_id=conversation_id, + stream=False, + query="查天气", + ) + for data in event_handler: + pass + event_handler.run( + query="查航班", + ) + for data in event_handler: + pass + event_handler.run( + query="CA1234", + ) + for data in event_handler: + pass + event_handler.run( + query="北京的", + ) + for data in event_handler: + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_appbuilder_client_chatflow_event_handler_v2.py b/python/tests/test_appbuilder_client_chatflow_event_handler_v2.py new file mode 100644 index 000000000..3ca9cb52b --- /dev/null +++ b/python/tests/test_appbuilder_client_chatflow_event_handler_v2.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import appbuilder +from appbuilder.core.console.appbuilder_client.event_handler import ( + AppBuilderEventHandler, +) + + +class MyEventHandler(AppBuilderEventHandler): + def __init__(self): + super().__init__() + self.interrupt_ids = [] + + def handle_content_type(self, run_context, run_response): + interrupt_event_id = None + event = run_response.events[-1] + if event.content_type == "chatflow_interrupt": + interrupt_event_id = event.detail.get("interrupt_event_id") + if interrupt_event_id is not None: + self.interrupt_ids.append(interrupt_event_id) + + def _create_action(self): + if len(self.interrupt_ids) == 0: + return None + event_id = self.interrupt_ids.pop() + return { + "action_type": "resume", + "parameters": {"interrupt_event": {"id": event_id, "type": "chat"}}, + } + + def gen_action(self): + while True: + yield self._create_action() + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientChatflow(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "4403205e-fb83-4fac-96d8-943bdb63796f" + + def test_appbuilder_client_run_with_handler_multiple_dialog(self): + if len(self.app_id) == 0: + self.skipTest("self.app_id is empty") + appbuilder.logger.setLevel("DEBUG") + builder = appbuilder.AppBuilderClient(self.app_id) + conversation_id = builder.create_conversation() + + queries = ["查天气", "查航班", "CA1234", "北京的"] + event_handler = MyEventHandler() + event_handler = builder.run_multiple_dialog_with_handler( + conversation_id=conversation_id, + queries=queries, + event_handler=event_handler, + stream=True, + actions=event_handler.gen_action(), + ) + for data in event_handler: + for ans in data: + pass + + +if __name__ == "__main__": + unittest.main()