diff --git a/go/appbuilder/component_client.go b/go/appbuilder/component_client.go index e18cb2f1..4ab5d2aa 100644 --- a/go/appbuilder/component_client.go +++ b/go/appbuilder/component_client.go @@ -41,7 +41,7 @@ func NewComponentClient(config *SDKConfig) (*ComponentClient, error) { return &ComponentClient{sdkConfig: config, client: client}, nil } -func (t *ComponentClient) Run(component, version, action string, req ComponentRunRequest) (ComponentClientIterator, error) { +func (t *ComponentClient) Run(component, version, action string, stream bool, parameters map[string]any) (ComponentClientIterator, error) { request := http.Request{} urlSuffix := fmt.Sprintf("/components/%s", component) @@ -66,6 +66,11 @@ func (t *ComponentClient) Run(component, version, action string, req ComponentRu request.Method = "POST" header.Set("Content-Type", "application/json") request.Header = header + + req := ComponentRunRequest{ + Stream: stream, + Parameters: parameters, + } data, _ := json.Marshal(req) request.Body = NopCloser(bytes.NewReader(data)) request.ContentLength = int64(len(data)) // 手动设置长度 diff --git a/go/appbuilder/component_client_data.go b/go/appbuilder/component_client_data.go index 8215c336..f760a4f3 100644 --- a/go/appbuilder/component_client_data.go +++ b/go/appbuilder/component_client_data.go @@ -47,18 +47,14 @@ type ComponentRunResponse struct { } type ComponentRunResponseData struct { - ConversationID string `json:"conversation_id"` - MessageID string `json:"message_id"` - TraceID string `json:"trace_id"` - UserID string `json:"user_id"` - EndUserID string `json:"end_user_id"` - IsCompletion bool `json:"is_completion"` - ComponentOutput ComponentOutput `json:"component_output"` -} - -type ComponentOutput struct { - Role string `json:"role"` - Content []Content `json:"content"` + ConversationID string `json:"conversation_id"` + MessageID string `json:"message_id"` + TraceID string `json:"trace_id"` + UserID string `json:"user_id"` + EndUserID string `json:"end_user_id"` + IsCompletion bool `json:"is_completion"` + Role string `json:"role"` + Content []Content `json:"content"` } type Content struct { diff --git a/go/appbuilder/component_client_test.go b/go/appbuilder/component_client_test.go new file mode 100644 index 00000000..aad7b2fa --- /dev/null +++ b/go/appbuilder/component_client_test.go @@ -0,0 +1,89 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package appbuilder + +import ( + "bytes" + "fmt" + "os" + "testing" +) + +func TestComponentClient(t *testing.T) { + var logBuffer bytes.Buffer + + // 设置环境变量 + os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG") + os.Setenv("APPBUILDER_LOGFILE", "") + + // 测试逻辑 + config, err := NewSDKConfig("", "") + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new http client config failed: %v", err) + } + + componentID := "44205c67-3980-41f7-aad4-37357b577fd0" + client, err := NewComponentClient(config) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new ComponentClient instance failed") + } + + parameters := map[string]any{ + SysOriginQuery: "北京景点推荐", + } + i, err := client.Run(componentID, "latest", "", false, parameters) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: %v", err) + } + + // test result + for answer, err := i.Next(); err == nil; answer, err = i.Next() { + data := answer.Content[0].Text + if data == nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: data is nil") + } + } + + i2, err := client.Run(componentID, "latest", "", true, parameters) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: %v", err) + } + + // test stream result + var answerText any + for answer, err := i2.Next(); err == nil; answer, err = i2.Next() { + if len(answer.Content) == 0 { + continue + } + answerText = answer.Content[0].Text + } + if answerText == nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: data is nil") + } + + // 如果测试失败,则输出缓冲区中的日志 + if t.Failed() { + fmt.Println(logBuffer.String()) + } else { // else 紧跟在右大括号后面 + // 测试通过,打印文件名和测试函数名 + t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java b/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java index 4993fe70..dcb0b8ff 100644 --- a/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java +++ b/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java @@ -1,7 +1,9 @@ package com.baidubce.appbuilder.console.componentclient; import java.io.IOException; +import java.util.HashMap; import java.util.Iterator; +import java.util.Map; import java.nio.charset.StandardCharsets; import org.apache.hc.core5.http.ClassicHttpRequest; @@ -13,7 +15,6 @@ import com.baidubce.appbuilder.base.utils.http.HttpResponse; import com.baidubce.appbuilder.base.utils.json.JsonUtils; import com.baidubce.appbuilder.model.componentclient.ComponentClientIterator; -import com.baidubce.appbuilder.model.componentclient.ComponentClientRunRequest; import com.baidubce.appbuilder.model.componentclient.ComponentClientRunResponse; public class ComponentClient extends Component { @@ -29,18 +30,23 @@ public ComponentClient(String secretKey, String gateway) { super(secretKey, gateway); } - /** - * 执行应用构建客户端运行请求 + /** + * 运行Component,根据输入的问题、会话ID、文件ID数组以及是否以流模式等信息返回结果,返回ComponentClientIterator迭代器。 * - * @param requestBody 请求体,包含运行所需的所有信息 - * @return 返回包含构建客户端响应的迭代器 - * @throws IOException 如果在执行请求时发生I/O错误 - * @throws AppBuilderServerException 如果应用构建服务器返回错误响应 + * + * @param componentId 组件ID + * @param version 组件版本 + * @param action 参数动作 + * @param stream 是否以流的形式返回结果 + * @param parameters 参数列表 + * @return ComponentCientIterator 迭代器,包含 ComponentCientIterator 的运行结果 + * @throws IOException 如果在 I/O 操作过程中发生错误 + * @throws AppBuilderServerException 如果 AppBuilder 服务器返回错误 */ - public ComponentClientIterator run(String component, String version, String action, ComponentClientRunRequest requestBody) + public ComponentClientIterator run(String componentId, String version, String action, boolean stream, Map parameters) throws IOException, AppBuilderServerException { String url = AppBuilderConfig.COMPONENT_RUN_URL; - String urlSuffix = String.format("%s/%s", url, component); + String urlSuffix = String.format("%s/%s", url, componentId); if (!version.isEmpty()) { urlSuffix += String.format("/version/%s", version); } @@ -52,7 +58,11 @@ public ComponentClientIterator run(String component, String version, String acti } } + Map requestBody = new HashMap<>(); + requestBody.put("parameters", parameters); + requestBody.put("stream", stream); String jsonBody = JsonUtils.serialize(requestBody); + ClassicHttpRequest postRequest = httpClient.createPostRequestV2(urlSuffix, new StringEntity(jsonBody, StandardCharsets.UTF_8)); postRequest.setHeader("Content-Type", "application/json"); diff --git a/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java index 3c08df96..dacd5727 100644 --- a/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java +++ b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java @@ -58,8 +58,8 @@ public static class ComponentRunResponseData { private String endUserID; @SerializedName("is_completion") private boolean isCompletion; - @SerializedName("component_output") - private ComponentOutput componentOutput; + private String role; + private Content[] content; public String getConversationID() { return conversationID; @@ -109,168 +109,155 @@ public void setCompletion(boolean completion) { isCompletion = completion; } - public ComponentOutput getComponentOutput() { - return componentOutput; + public String getRole() { + return role; } - public void setComponentOutput(ComponentOutput componentOutput) { - this.componentOutput = componentOutput; + public void setRole(String role) { + this.role = role; } - public static class ComponentOutput { - private String role; - private Content content; + public Content[] getContent() { + return content; + } + + public void setContent(Content[] content) { + this.content = content; + } - public String getRole() { - return role; + public static class Content { + private String name; + @SerializedName("visible_scope") + private String visibleScope; + @SerializedName("raw_data") + private Map rawData = new HashMap<>(); + private Map usage = new HashMap<>(); + private Map metrics = new HashMap<>(); + private String type; + private Map text = new HashMap<>(); + private ComponentEvent event; + + public String getName() { + return name; } - public void setRole(String role) { - this.role = role; + public void setName(String name) { + this.name = name; } - public Content getContent() { - return content; + public String getVisibleScope() { + return visibleScope; } - public void setContent(Content content) { - this.content = content; + public void setVisibleScope(String visibleScope) { + this.visibleScope = visibleScope; } - public static class Content { - private String name; - @SerializedName("visible_scope") - private String visibleScope; - @SerializedName("raw_data") - private Map rawData = new HashMap<>(); - private Map usage = new HashMap<>(); - private Map metrics = new HashMap<>(); - private String type; - private Map text = new HashMap<>(); - private ComponentEvent event; + public Map getRawData() { + return rawData; + } - public String getName() { - return name; - } + public void setRawData(Map rawData) { + this.rawData = rawData; + } - public void setName(String name) { - this.name = name; - } + public Map getUsage() { + return usage; + } - public String getVisibleScope() { - return visibleScope; - } + public void setUsage(Map usage) { + this.usage = usage; + } - public void setVisibleScope(String visibleScope) { - this.visibleScope = visibleScope; - } + public Map getMetrics() { + return metrics; + } - public Map getRawData() { - return rawData; - } + public void setMetrics(Map metrics) { + this.metrics = metrics; + } - public void setRawData(Map rawData) { - this.rawData = rawData; + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Map getText() { + return text; + } + + public void setText(Map text) { + this.text = text; + } + + public ComponentEvent getEvent() { + return event; + } + + public void setEvent(ComponentEvent event) { + this.event = event; + } + + public static class ComponentEvent { + private String id; + private String status; + private String name; + @SerializedName("created_time") + private String createdTime; + @SerializedName("error_code") + private String errorCode; + @SerializedName("error_message") + private String errorMessage; + + public String getId() { + return id; } - public Map getUsage() { - return usage; + public void setId(String id) { + this.id = id; } - public void setUsage(Map usage) { - this.usage = usage; + public String getStatus() { + return status; } - public Map getMetrics() { - return metrics; + public void setStatus(String status) { + this.status = status; } - public void setMetrics(Map metrics) { - this.metrics = metrics; + public String getName() { + return name; } - public String getType() { - return type; + public void setName(String name) { + this.name = name; } - public void setType(String type) { - this.type = type; + public String getCreatedTime() { + return createdTime; } - public Map getText() { - return text; + public void setCreatedTime(String createdTime) { + this.createdTime = createdTime; } - public void setText(Map text) { - this.text = text; + public String getErrorCode() { + return errorCode; } - public ComponentEvent getEvent() { - return event; + public void setErrorCode(String errorCode) { + this.errorCode = errorCode; } - public void setEvent(ComponentEvent event) { - this.event = event; + public String getErrorMessage() { + return errorMessage; } - public static class ComponentEvent { - private String id; - private String status; - private String name; - @SerializedName("created_time") - private String createdTime; - @SerializedName("error_code") - private String errorCode; - @SerializedName("error_message") - private String errorMessage; - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public String getStatus() { - return status; - } - - public void setStatus(String status) { - this.status = status; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getCreatedTime() { - return createdTime; - } - - public void setCreatedTime(String createdTime) { - this.createdTime = createdTime; - } - - public String getErrorCode() { - return errorCode; - } - - public void setErrorCode(String errorCode) { - this.errorCode = errorCode; - } - - public String getErrorMessage() { - return errorMessage; - } - - public void setErrorMessage(String errorMessage) { - this.errorMessage = errorMessage; - } + public void setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; } } } diff --git a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java index f4be4037..958cec6b 100644 --- a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java @@ -1,9 +1,5 @@ package com.baidubce.appbuilder; -import com.baidubce.appbuilder.base.exception.AppBuilderServerException; -import com.baidubce.appbuilder.console.appbuilderclient.AppBuilderClient; -import com.baidubce.appbuilder.console.appbuilderclient.AppList; - import java.io.IOException; import java.nio.file.Paths; import java.nio.file.Files; @@ -11,16 +7,20 @@ import java.util.Map; import java.util.Stack; import java.util.List; +import org.junit.Before; +import org.junit.Test; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientIterator; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult; +import com.baidubce.appbuilder.base.exception.AppBuilderServerException; +import com.baidubce.appbuilder.console.appbuilderclient.AppBuilderClient; +import com.baidubce.appbuilder.console.appbuilderclient.AppList; import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest; import com.baidubce.appbuilder.model.appbuilderclient.AppsDescribeRequest; import com.baidubce.appbuilder.model.appbuilderclient.Event; import com.baidubce.appbuilder.model.appbuilderclient.EventContent; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientRunRequest; -import org.junit.Before; -import org.junit.Test; + import static org.junit.Assert.*; diff --git a/java/src/test/java/com/baidubce/appbuilder/ComponentClientTest.java b/java/src/test/java/com/baidubce/appbuilder/ComponentClientTest.java new file mode 100644 index 00000000..2971cc55 --- /dev/null +++ b/java/src/test/java/com/baidubce/appbuilder/ComponentClientTest.java @@ -0,0 +1,54 @@ +package com.baidubce.appbuilder; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import static org.junit.Assert.*; + +import com.baidubce.appbuilder.base.exception.AppBuilderServerException; +import com.baidubce.appbuilder.console.componentclient.ComponentClient; +import com.baidubce.appbuilder.model.componentclient.ComponentClientIterator; +import com.baidubce.appbuilder.model.componentclient.ComponentClientRunRequest; +import com.baidubce.appbuilder.model.componentclient.ComponentClientRunResponse; + +public class ComponentClientTest { + String componentId; + + @Before + public void setUp() { + System.setProperty("APPBUILDER_TOKEN", System.getenv("APPBUILDER_TOKEN")); + System.setProperty("APPBUILDER_LOGLEVEL", "DEBUG"); + componentId = "44205c67-3980-41f7-aad4-37357b577fd0"; + } + + @Test + public void TestComponentClientRun() throws IOException, AppBuilderServerException { + ComponentClient client = new ComponentClient(); + Map parameters = new HashMap<>(); + parameters.put(ComponentClientRunRequest.SysOriginQuery, "北京景点推荐"); + ComponentClientIterator iter = client.run(componentId, "latest", "", false, parameters); + while (iter.hasNext()) { + ComponentClientRunResponse.ComponentRunResponseData response = iter.next(); + assertNotNull(response.getContent()[0].getText()); + } + } + + @Test + public void TestComponentClientRunStream() throws IOException, AppBuilderServerException { + ComponentClient client = new ComponentClient(); + Map parameters = new HashMap<>(); + parameters.put(ComponentClientRunRequest.SysOriginQuery, "北京景点推荐"); + ComponentClientIterator iter = client.run(componentId, "latest", "", true, parameters); + Object text = null; + while (iter.hasNext()) { + ComponentClientRunResponse.ComponentRunResponseData response = iter.next(); + if (response.getContent().length > 0) { + text = response.getContent()[0].getText(); + } + } + assertNotNull(text); + } +} diff --git a/python/core/console/component_client/component_client.py b/python/core/console/component_client/component_client.py index a033cffe..7f1d93b3 100644 --- a/python/core/console/component_client/component_client.py +++ b/python/core/console/component_client/component_client.py @@ -18,7 +18,7 @@ from appbuilder.core.console.component_client import data_class from appbuilder.core._exception import AppBuilderServerException from appbuilder.utils.logger_util import logger -from appbuilder.utils.trace.tracer_wrapper import client_tool_trace +from appbuilder.utils.trace.tracer_wrapper import client_run_trace from appbuilder.utils.sse_util import SSEClient @@ -31,10 +31,10 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) - @client_tool_trace + @client_run_trace def run( self, - component: str, + component_id: str, sys_origin_query: str, version: str = None, action: str = None, @@ -44,11 +44,26 @@ def run( sys_end_user_id: str = None, sys_chat_history: list = None, **kwargs, - ) -> data_class.RunResponse: + ) -> Message: + """ 组件运行 + Args: + component_id (str): 组件ID + sys_origin_query (str): 用户输入的原始查询语句 + version (str): 组件版本号 + action (str): 组件动作 + stream (bool): 是否流式返回 + sys_file_urls (dict): 文件地址 + sys_conversation_id (str): 会话ID + sys_end_user_id (str): 用户ID + sys_chat_history (list): 聊天 + kwargs: 其他参数 + Returns: + message (Message): 对话结果,一个Message对象,使用message.content获取内容。 + """ headers = self.http_client.auth_header_v2() headers["Content-Type"] = "application/json" - url_suffix = f"/components/{component}" + url_suffix = f"/components/{component_id}" if version is not None: url_suffix += f"/version/{version}" if action is not None: diff --git a/python/tests/test_component_client.py b/python/tests/test_component_client.py new file mode 100644 index 00000000..62094bb7 --- /dev/null +++ b/python/tests/test_component_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import appbuilder +import os + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestComponentCLient(unittest.TestCase): + def test_component_client(self): + appbuilder.logger.setLoglevel("DEBUG") + client = appbuilder.ComponentClient() + + res = client.run(component="44205c67-3980-41f7-aad4-37357b577fd0", + version="latest", sys_origin_query="北京景点推荐") + print(res.content.content) + + def test_component_client_stream(self): + appbuilder.logger.setLoglevel("DEBUG") + client = appbuilder.ComponentClient() + + res = client.run(component="44205c67-3980-41f7-aad4-37357b577fd0", + version="latest", sys_origin_query="北京景点推荐", stream=True) + for data in res.content: + print(data) + + +if __name__ == "__main__": + unittest.main()