Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
yepeiwen01 committed Dec 23, 2024
2 parents 032622d + d55e6fc commit 9d591e1
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 64 deletions.
6 changes: 3 additions & 3 deletions python/core/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __init__(self, secret_key=None, gateway="", gateway_v2=""):
self.session = AsyncInnerSession()

@staticmethod
def check_response_header(response: ClientResponse):
async def check_response_header(response: ClientResponse):
r"""check_response_header is a helper method for check head status .
:param response: requests.Response.
:rtype:
Expand All @@ -252,7 +252,7 @@ def check_response_header(response: ClientResponse):
if status_code == requests.codes.ok:
return
message = "request_id={} , http status code is {}, body is {}".format(
__class__.response_request_id(response), status_code, response.text
await __class__.response_request_id(response), status_code, await response.text()
)
if status_code == requests.codes.bad_request:
raise BadRequestException(message)
Expand All @@ -268,7 +268,7 @@ def check_response_header(response: ClientResponse):
raise BaseRPCException(message)

@staticmethod
def response_request_id(response: ClientResponse):
async def response_request_id(response: ClientResponse):
r"""response_request_id is a helper method to get the unique request id"""
return response.headers.get("X-Appbuilder-Request-Id", "")

Expand Down
4 changes: 0 additions & 4 deletions python/core/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,18 @@ async def build_curl(self, method, url, data=None, json_data=None, **kwargs) ->

return curl

@session_post
async def post(self, url, data=None, json=None, **kwargs):
logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_POST, url, data=data, json_data=json, **kwargs) + "\n")
return await super().post(url=url, data=data, json=json, **kwargs)

@session_post
async def delete(self, url, **kwargs):
logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_DELETE, url, **kwargs) + "\n")
return await super().delete(url=url, **kwargs)

@session_post
async def get(self, url, **kwargs):
logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_GET, url, **kwargs) + "\n")
return await super().get(url=url, **kwargs)

@session_post
async def put(self, url, data=None, **kwargs):
logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_PUT, url, data=data, **kwargs) + "\n")
return await super().put(url=url, data=data, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def create_conversation(self) -> str:
response = await self.http_client.session.post(
url, headers=headers, json={"app_id": self.app_id}, timeout=None
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.CreateConversationResponse(**data)
return resp.conversation_id
Expand Down Expand Up @@ -116,8 +116,8 @@ async def run(
response = await self.http_client.session.post(
url, headers=headers, json=req.model_dump(), timeout=None
)
self.http_client.check_response_header(response)
request_id = self.http_client.response_request_id(response)
await self.http_client.check_response_header(response)
request_id = await self.http_client.response_request_id(response)
if stream:
client = AsyncSSEClient(response)
return Message(content=self._iterate_events(request_id, client.events()))
Expand Down Expand Up @@ -164,7 +164,7 @@ async def upload_local_file(self, conversation_id, local_file_path: str) -> str:
response = await self.http_client.session.post(
url, data=multipart_form_data, headers=headers
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.FileUploadResponse(**data)
return resp.id
Expand Down
6 changes: 3 additions & 3 deletions python/core/console/appbuilder_client/async_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ async def __async_run_process__(self):
while not self._is_complete:
if not self._need_tool_call:
res = await self._run()
self.__event_process__(res)
await self.__event_process__(res)
else:
res = await self._submit_tool_output()
self.__event_process__(res)
await self.__event_process__(res)
yield res
if self._need_tool_call and self._is_complete:
self.reset_state()
await self.reset_state()

async def __event_process__(self, run_response):
"""
Expand Down
12 changes: 8 additions & 4 deletions python/tests/component_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
SKIP_COMPONENTS = [
]

V2_SKIP_COMPONENTS = [
"ASR",
]

# 白名单中的组件因历史原因,检查失败,但可以正常使用,因此加入白名单
COMPONENT_WHITE_LIST = [
"RagWithBaiduSearchPro",
Expand Down Expand Up @@ -75,10 +79,10 @@
def get_component_white_list():
return COMPONENT_WHITE_LIST

def get_components(components_list, import_prefix):
def get_components(components_list, import_prefix, skip_components):
components = {}
for component in components_list:
if component in SKIP_COMPONENTS:
if component in skip_components:
continue

try:
Expand All @@ -98,12 +102,12 @@ def get_components(components_list, import_prefix):

def get_all_components():
from appbuilder import __COMPONENTS__
all_components = get_components(__COMPONENTS__, "appbuilder.")
all_components = get_components(__COMPONENTS__, "appbuilder.", SKIP_COMPONENTS)
return all_components

def get_v2_components():
from appbuilder.core.components.v2 import __V2_COMPONENTS__
v2_components = get_components(__V2_COMPONENTS__, "appbuilder.core.components.v2.")
v2_components = get_components(__V2_COMPONENTS__, "appbuilder.core.components.v2.", V2_SKIP_COMPONENTS)
return v2_components

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_all_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from component_check import check_component_with_retry, write_error_data


@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "")
@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "")
class TestComponentManifestsAndToolEval(unittest.TestCase):
"""
组件manifests和tool_eval入参测试类
Expand Down
15 changes: 7 additions & 8 deletions python/tests/test_appbuilder_components_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def setUp(self):
无返回值。
"""
self.audio_file_url = "https://bj.bcebos.com/v1/appbuilder/asr_test.pcm?authorization=bce-auth-v1" \
"%2FALTAKGa8m4qCUasgoljdEDAzLm%2F2024-01-11T10%3A56%3A41Z%2F-1%2Fhost" \
"%2Fa6c4d2ca8a3f0259f4cae8ae3fa98a9f75afde1a063eaec04847c99ab7d1e411"
self.asr = appbuilder.ASR()
self.image_url = "https://bj.bcebos.com/v1/appbuilder/table_ocr_test.png?"\
"authorization=bce-auth-v1%2FALTAKGa8m4qCUasgoljdEDAzLm%2F2024-01-24T12%3A37%3A09Z%2F-1%2Fhost%2Fab528a5a9120d328dc6d18c6"\
"064079145ff4698856f477b820147768fc2187d3"
self.table_ocr = appbuilder.TableOCR()
self.play = appbuilder.Playground(prompt_template="你好,{name},我是{bot_name},{bot_name}是一个{bot_type},我可以{bot_function},你可以问我{bot_question}。", model="ERNIE-3.5-8K")
model_name = "ERNIE-3.5-8K"
secret_key = os.getenv('SECRET_KEY', None)
Expand All @@ -77,10 +77,9 @@ def test_trace(self):
tracer.start_trace()

# test asr run and tool_eval
raw_audio = requests.get(self.audio_file_url).content
inp = appbuilder.Message(content={"raw_audio": raw_audio})
out = self.asr.run(inp)
result = self.asr.tool_eval(name="asr", streaming=True, file_url=self.audio_file_url)
out = self.table_ocr.run(appbuilder.Message(content={"url": self.image_url}))
print(out)
result = self.table_ocr.tool_eval(name="asr", streaming=True, file_names=[self.image_url])
for res in result:
print(res)

Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from appbuilder.core.components.asr.model import ShortSpeechRecognitionRequest, ShortSpeechRecognitionResponse
import os

@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "")
@unittest.skip("测试API超限,暂时跳过")
class TestASRComponent(unittest.TestCase):
def setUp(self):
"""
Expand Down
13 changes: 7 additions & 6 deletions python/tests/test_async_appbuilder_client_toolcall.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ async def interrupt(self, run_context, run_response):
tool_call_id = tool_call.id
tool_res = self.get_current_weather(**tool_call.function.arguments)
# 蓝色打印
print("\033[1;34m", "-> 本地ToolCall结果: ", tool_res, "\033[0m\n")
print("\033[1;34m", "-> 本地ToolCallId: ", tool_call_id, "\033[0m")
print("\033[1;34m", "-> ToolCall结果: ", tool_res, "\033[0m\n")
tool_output.append(
{"tool_call_id": tool_call_id, "output": tool_res})
return tool_output
Expand Down Expand Up @@ -92,9 +93,10 @@ def test_appbuilder_client_tool_call(self):
}
]

appbuilder.logger.setLoglevel("ERROR")
appbuilder.logger.setLoglevel("DEBUG")

async def agent_run(client, conversation_id, query):
async def agent_run(client, query):
conversation_id = await client.create_conversation()
with await client.run_with_handler(
conversation_id=conversation_id,
query=query,
Expand All @@ -105,11 +107,10 @@ async def agent_run(client, conversation_id, query):

async def agent_handle():
client = appbuilder.AsyncAppBuilderClient(self.app_id)
conversation_id = await client.create_conversation()
task1 = asyncio.create_task(
agent_run(client, conversation_id, "北京的天气怎么样"))
agent_run(client, "北京的天气怎么样"))
task2 = asyncio.create_task(
agent_run(client, conversation_id, "上海的天气怎么样"))
agent_run(client, "上海的天气怎么样"))
await asyncio.gather(task1, task2)

await client.http_client.session.close()
Expand Down
60 changes: 32 additions & 28 deletions python/tests/test_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import unittest
import json
import asyncio

from appbuilder.core._client import HTTPClient, AsyncHTTPClient
from appbuilder.core._exception import *
Expand Down Expand Up @@ -100,34 +101,37 @@ def test_core_client_check_response_header(self):
HTTPClient.check_response_header(response)

def test_core_client_check_async_response_header(self):
# 测试各种response报错
response = AsyncResponse(
status_code=400,
headers={'Content-Type': 'application/json'},
text='{"code": 0, "message": "success"}'
)
with self.assertRaises(BadRequestException):
AsyncHTTPClient.check_response_header(response)

response.status = 403
with self.assertRaises(ForbiddenException):
AsyncHTTPClient.check_response_header(response)

response.status = 404
with self.assertRaises(NotFoundException):
AsyncHTTPClient.check_response_header(response)

response.status = 428
with self.assertRaises(PreconditionFailedException):
AsyncHTTPClient.check_response_header(response)

response.status = 500
with self.assertRaises(InternalServerErrorException):
AsyncHTTPClient.check_response_header(response)

response.status = 201
with self.assertRaises(BaseRPCException):
AsyncHTTPClient.check_response_header(response)
async def run_test():
# 测试各种response报错
response = AsyncResponse(
status_code=400,
headers={'Content-Type': 'application/json'},
text=lambda:asyncio.sleep(0) or '{"code": 0, "message": "success"}'
)
with self.assertRaises(BadRequestException):
await AsyncHTTPClient.check_response_header(response)

response.status = 403
with self.assertRaises(ForbiddenException):
await AsyncHTTPClient.check_response_header(response)

response.status = 404
with self.assertRaises(NotFoundException):
await AsyncHTTPClient.check_response_header(response)

response.status = 428
with self.assertRaises(PreconditionFailedException):
await AsyncHTTPClient.check_response_header(response)

response.status = 500
with self.assertRaises(InternalServerErrorException):
await AsyncHTTPClient.check_response_header(response)

response.status = 201
with self.assertRaises(BaseRPCException):
await AsyncHTTPClient.check_response_header(response)
loop = asyncio.get_event_loop()
loop.run_until_complete(run_test())

def test_core_client_check_response_json(self):
data = {
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from appbuilder.core._exception import InvalidRequestArgumentError
import os

@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "")
@unittest.skip("测试API超限,暂时跳过")
class TestTTS(unittest.TestCase):
def setUp(self):
self.tts = appbuilder.TTS()
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_v2_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from appbuilder.core.components.v2 import ASR
from appbuilder.core.components.v2.asr.component import _convert as convert

@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "")
@unittest.skip("测试API超限,暂时跳过")
class TestASR(unittest.TestCase):
def setUp(self):
self.audio_file_url = "https://bj.bcebos.com/v1/appbuilder/asr_test.pcm?authorization=bce-auth-v1" \
Expand Down

0 comments on commit 9d591e1

Please sign in to comment.