Skip to content

Commit

Permalink
chatflow工作流应用对话验证,增加'信息收集节点'回复功能 (baidubce#601)
Browse files Browse the repository at this point in the history
* chatflow工作流应用对话验证,增加'信息收集节点'回复功能
  • Loading branch information
userpj authored Nov 20, 2024
1 parent 5e9d83f commit 5d33bdf
Show file tree
Hide file tree
Showing 14 changed files with 947 additions and 128 deletions.
70 changes: 55 additions & 15 deletions go/appbuilder/app_builder_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"strconv"
"time"
)
Expand Down Expand Up @@ -250,51 +251,90 @@ func (t *AppBuilderClient) UploadLocalFile(conversationID string, filePath strin
return val.(string), nil
}

func (t *AppBuilderClient) Run(conversationID string, query string, fileIDS []string, stream bool) (AppBuilderClientIterator, error) {
if len(conversationID) == 0 {
return nil, errors.New("conversationID mustn't be empty")
func (t *AppBuilderClient) Run(param ...interface{}) (AppBuilderClientIterator, error) {
if len(param) == 0 {
return nil, errors.New("no arguments provided")
}
var err error
var req AppBuilderClientRunRequest

if reflect.TypeOf(param[0]) == reflect.TypeOf(AppBuilderClientRunRequest{}) {
req = param[0].(AppBuilderClientRunRequest)
} else {
req, err = t.buildAppBuilderClientRunRequest(param...)
if err != nil {
return nil, err
}
}
m := map[string]any{
"app_id": t.appID,
"conversation_id": conversationID,
"query": query,
"file_ids": fileIDS,
"stream": stream,

if len(req.ConversationID) == 0 {
return nil, errors.New("conversationID mustn't be empty")
}

request := http.Request{}

serviceURL, err := t.sdkConfig.ServiceURLV2("/app/conversation/runs")
if err != nil {
return nil, err
}

header := t.sdkConfig.AuthHeaderV2()
request.URL = serviceURL
request.Method = "POST"
header := t.sdkConfig.AuthHeaderV2()
header.Set("Content-Type", "application/json")
request.Header = header
data, _ := json.Marshal(m)
data, _ := json.Marshal(req)
request.Body = NopCloser(bytes.NewReader(data))
request.ContentLength = int64(len(data)) // 手动设置长度
request.ContentLength = int64(len(data))

t.sdkConfig.BuildCurlCommand(&request)

resp, err := t.client.Do(&request)
if err != nil {
return nil, err
}

requestID, err := checkHTTPResponse(resp)
if err != nil {
return nil, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}
r := NewSSEReader(1024*1024, bufio.NewReader(resp.Body))
if stream {
if req.Stream {
return &AppBuilderClientStreamIterator{requestID: requestID, r: r, body: resp.Body}, nil
}
return &AppBuilderClientOnceIterator{body: resp.Body}, nil
}

func (t *AppBuilderClient) buildAppBuilderClientRunRequest(param ...interface{}) (AppBuilderClientRunRequest, error) {
conversationID, ok := param[0].(string)
if !ok {
return AppBuilderClientRunRequest{}, errors.New("conversationID must be string type")
}
query, ok := param[1].(string)
if !ok {
return AppBuilderClientRunRequest{}, errors.New("query must be string type")
}

var fileIDS []string
if param[2] != nil {
fileIDS, ok = param[2].([]string)
if !ok {
fileIDS = nil
}
}

stream, ok := param[3].(bool)
if !ok {
stream = false
}

return AppBuilderClientRunRequest{
AppID: t.appID,
ConversationID: conversationID,
Query: query,
Stream: stream,
FileIDs: fileIDS,
}, nil
}

func (t *AppBuilderClient) RunWithToolCall(req AppBuilderClientRunRequest) (AppBuilderClientIterator, error) {
if len(req.ConversationID) == 0 {
return nil, errors.New("conversationID mustn't be empty")
Expand Down
84 changes: 67 additions & 17 deletions go/appbuilder/app_builder_client_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,33 @@ import (
)

const (
CodeContentType = "code"
TextContentType = "text"
ImageContentType = "image"
RAGContentType = "rag"
FunctionCallContentType = "function_call"
AudioContentType = "audio"
VideoContentType = "video"
StatusContentType = "status"
CodeContentType = "code"
TextContentType = "text"
ImageContentType = "image"
RAGContentType = "rag"
FunctionCallContentType = "function_call"
AudioContentType = "audio"
VideoContentType = "video"
StatusContentType = "status"
ChatflowInterruptContentType = "chatflow_interrupt"
PublishMessageContentType = "publish_message"
)

const (
ChatflowEventType = "chatflow"
)

var TypeToStruct = map[string]reflect.Type{
CodeContentType: reflect.TypeOf(CodeDetail{}),
TextContentType: reflect.TypeOf(TextDetail{}),
ImageContentType: reflect.TypeOf(ImageDetail{}),
RAGContentType: reflect.TypeOf(RAGDetail{}),
FunctionCallContentType: reflect.TypeOf(FunctionCallDetail{}),
AudioContentType: reflect.TypeOf(AudioDetail{}),
VideoContentType: reflect.TypeOf(VideoDetail{}),
StatusContentType: reflect.TypeOf(StatusDetail{}),
CodeContentType: reflect.TypeOf(CodeDetail{}),
TextContentType: reflect.TypeOf(TextDetail{}),
ImageContentType: reflect.TypeOf(ImageDetail{}),
RAGContentType: reflect.TypeOf(RAGDetail{}),
FunctionCallContentType: reflect.TypeOf(FunctionCallDetail{}),
AudioContentType: reflect.TypeOf(AudioDetail{}),
VideoContentType: reflect.TypeOf(VideoDetail{}),
StatusContentType: reflect.TypeOf(StatusDetail{}),
ChatflowInterruptContentType: reflect.TypeOf(ChatflowInterruptDetail{}),
PublishMessageContentType: reflect.TypeOf(PublishMessageDetail{}),
}

type AppBuilderClientRunRequest struct {
Expand All @@ -50,9 +58,11 @@ type AppBuilderClientRunRequest struct {
Stream bool `json:"stream"`
EndUserID *string `json:"end_user_id"`
ConversationID string `json:"conversation_id"`
FileIDs []string `json:"file_ids"`
Tools []Tool `json:"tools"`
ToolOutputs []ToolOutput `json:"tool_outputs"`
ToolChoice *ToolChoice `json:"tool_choice"`
Action *Action `json:"action"`
}

type Tool struct {
Expand Down Expand Up @@ -81,6 +91,36 @@ type ToolChoiceFunction struct {
Input map[string]interface{} `json:"input"`
}

type Action struct {
ActionType string `json:"action_type"`
Paramters *ActionParamters `json:"parameters"`
}

type ActionParamters struct {
InterruptEvent *ActionInterruptEvent `json:"interrupt_event"`
}

type ActionInterruptEvent struct {
ID string `json:"id"`
Type string `json:"type"`
}

func NewResumeAction(eventId string) *Action {
return NewAction("resume", eventId, "chat")
}

func NewAction(actionType string, eventId string, eventType string) *Action {
return &Action{
ActionType: actionType,
Paramters: &ActionParamters{
InterruptEvent: &ActionInterruptEvent{
ID: eventId,
Type: eventType,
},
},
}
}

type AgentBuilderRawResponse struct {
RequestID string `json:"request_id"`
Date string `json:"date"`
Expand Down Expand Up @@ -122,7 +162,7 @@ type Event struct {
EventType string
ContentType string
Usage Usage
Detail any // 将any替换为interface{}
Detail any
ToolCalls []ToolCall
}

Expand Down Expand Up @@ -185,6 +225,16 @@ type VideoDetail struct {

type StatusDetail struct{}

type ChatflowInterruptDetail struct {
InterruptEventID string `json:"interrupt_event_id"`
InterruptEventType string `json:"interrupt_event_type"`
}

type PublishMessageDetail struct {
Message string `json:"message"`
MessageID string `json:"message_id"`
}

type DefaultDetail struct {
URLS []string `json:"urls"`
Files []string `json:"files"`
Expand Down
102 changes: 101 additions & 1 deletion go/appbuilder/app_builder_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ func TestAppBuilderClientRunToolChoice(t *testing.T) {
input := make(map[string]any)
input["city"] = "北京"
end_user_id := "go_user_id_0"
i, err := client.RunWithToolCall(AppBuilderClientRunRequest{
i, err := client.Run(AppBuilderClientRunRequest{
ConversationID: conversationID,
AppID: appID,
Query: "你能干什么",
Expand Down Expand Up @@ -610,3 +610,103 @@ func TestAppBuilderClientRunToolChoice(t *testing.T) {
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}

func TestAppBuilderClientRunChatflow(t *testing.T) {
var logBuffer bytes.Buffer

os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG")

config, err := NewSDKConfig("", "")
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new http client config failed: %v", err)
}

appID := "4403205e-fb83-4fac-96d8-943bdb63796f"
client, err := NewAppBuilderClient(appID, config)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new AgentBuidler instance failed")
}

conversationID, err := client.CreateConversation()
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("create conversation failed: %v", err)
}

i, err := client.Run(AppBuilderClientRunRequest{
ConversationID: conversationID,
AppID: appID,
Query: "查天气",
Stream: true,
})

if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("run failed:%v", err)
}

var interruptId string
for answer, err := i.Next(); err == nil; answer, err = i.Next() {
for _, ev := range answer.Events {
if ev.ContentType == ChatflowInterruptContentType {
if ev.EventType != ChatflowEventType {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("event type error:%v", err)
}

deatil := ev.Detail.(ChatflowInterruptDetail)
interruptId = deatil.InterruptEventID
break
}
}
}

if len(interruptId) == 0 {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("interrupt id is empty")
}

i2, err := client.Run(AppBuilderClientRunRequest{
ConversationID: conversationID,
AppID: appID,
Query: "我先查个航班动态",
Stream: true,
Action: NewResumeAction(interruptId),
})
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("run failed:%v", err)
}

var message string
for answer, err := i2.Next(); err == nil; answer, err = i2.Next() {
for _, ev := range answer.Events {
if ev.ContentType == PublishMessageContentType {
if ev.EventType != ChatflowEventType {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("event type error:%v", err)
}

detail := ev.Detail.(PublishMessageDetail)
message = detail.Message
break
}
}
}
if len(message) == 0 {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("message is empty")
}
fmt.Println(message)

// 如果测试失败,则输出缓冲区中的日志
if t.Failed() {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
fmt.Println(logBuffer.String())
} else { // else 紧跟在右大括号后面
// 测试通过,打印文件名和测试函数名
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}
Loading

0 comments on commit 5d33bdf

Please sign in to comment.