From a746f08baf224afe0c1c1df5d672e6f424b89032 Mon Sep 17 00:00:00 2001 From: peiwenYe <963623403@qq.com> Date: Mon, 23 Dec 2024 15:47:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9reference=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E4=BF=9D=E7=95=99=E5=AD=97=E6=AE=B5=EF=BC=9B=E7=BB=84=E4=BB=B6?= =?UTF-8?q?=E6=A0=87=E5=87=86=E5=8C=96=E5=8D=95=E6=B5=8B=E6=A1=86=E6=9E=B6?= =?UTF-8?q?=E6=9B=B4=E6=96=B0:=20=E6=9B=B4=E6=96=B0=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?=E5=8F=98=E9=87=8F=EF=BC=8C=E5=A2=9E=E5=8A=A0tool=5Feval?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E5=92=8Cmanifests=E5=8C=B9=E9=85=8D=E6=80=A7?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=20(#680)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化组件标准化单测框架:更新系统变量,增加tool_eval参数和manifests匹配性检查 * 组件标准化manifests更改回滚 * 修改references类型的保留字段 * 修改manifests改动对应的单测 * 修改manifests改动对应的单测 --------- Co-authored-by: yepeiwen01 --- python/core/component.py | 13 +- .../components/v2/handwrite_ocr/component.py | 15 +- .../components/v2/mix_card_ocr/component.py | 24 +- .../components/v2/qrcode_ocr/component.py | 15 +- .../core/components/v2/tree_mind/component.py | 13 + python/tests/component_check.py | 249 +++++++++++------- python/tests/component_tool_eval_cases.py | 6 +- python/tests/test_all_components.py | 99 ++----- python/tests/test_base_component.py | 3 + python/tests/test_v2_handwrite_ocr.py | 2 +- python/tests/test_v2_mix_card_ocr.py | 4 +- python/tests/test_v2_qrcode_ocr.py | 2 +- 12 files changed, 204 insertions(+), 241 deletions(-) diff --git a/python/core/component.py b/python/core/component.py index 23bcbac7..5628ed9f 100644 --- a/python/core/component.py +++ b/python/core/component.py @@ -83,17 +83,11 @@ class OralText(BaseModel, extra='allow'): class References(BaseModel, extra='allow'): type: str = Field(default="", description="类型") - resource_type: str = Field(default="", description="资源类型") - icon: str = Field(default="", description="站点图标") - site_name: str = Field(default="", description="站点名") source: str = Field(default="", description="来源") doc_id: str = Field(default="", description="文档id") title: str = Field(default="", description="标题") content: str = Field(default="", description="内容") - image_content: str = Field(default="", description="图片内容") - mock_id: Optional[str] = Field(default="", description="模拟数据id") - image_url: str = Field(default="", description="图片url") - video_url: str = Field(default="", description="视频url") + extra: Optional[dict] = Field(default={}, description="其他信息") class Image(BaseModel, extra='allow'): @@ -548,8 +542,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra elif type == "files": key_list = ["filename", "url"] elif type == "references": - key_list = ["type", "resource_type", "icon", "site_name", "source", - "doc_id", "title", "content", "image_content", "image_url", "video_url"] + key_list = ["type", "source", "doc_id", "title", "content"] elif type == "image": key_list = ["filename", "url"] elif type == "chart": @@ -562,7 +555,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra key_list = ["thought", "name", "arguments"] else: raise ValueError("Unknown type: {}".format(type)) - # assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list) + assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list) else: raise ValueError("text must be str or dict") diff --git a/python/core/components/v2/handwrite_ocr/component.py b/python/core/components/v2/handwrite_ocr/component.py index 825666ed..06a96086 100644 --- a/python/core/components/v2/handwrite_ocr/component.py +++ b/python/core/components/v2/handwrite_ocr/component.py @@ -59,13 +59,6 @@ class HandwriteOCR(Component): "type": "string" }, "description": "待识别文件的文件名" - }, - "file_urls": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "待识别文件的url下载地址" } }, "required": ["file_names"] @@ -114,13 +107,11 @@ def run(self, message: Message, timeout: float = None, retry: int = 0) -> Messag @components_run_stream_trace def tool_eval(self, file_names: Optional[list] = [], - file_urls: Optional[dict] = {}, **kwargs): """ 工具评估函数 Args: file_names (Optional[list]): 待识别文件的文件名列表 - file_urls (Optional[dict]): 待识别文件的url下载地址字典 **kwargs: 其他参数 Raises: @@ -133,12 +124,10 @@ def tool_eval(self, result = "" sys_file_names = file_names - sys_file_urls = file_urls - if not sys_file_names: sys_file_names = kwargs.get('_sys_file_names', []) - if not sys_file_urls: - sys_file_urls = kwargs.get('_sys_file_urls', {}) + + sys_file_urls = kwargs.get('_sys_file_urls', {}) for file_name in sys_file_names: if utils.is_url(file_name): diff --git a/python/core/components/v2/mix_card_ocr/component.py b/python/core/components/v2/mix_card_ocr/component.py index 0f70387c..1c700dad 100644 --- a/python/core/components/v2/mix_card_ocr/component.py +++ b/python/core/components/v2/mix_card_ocr/component.py @@ -63,13 +63,6 @@ class MixCardOCR(Component): "type": "string" }, "description": "待识别文件的文件名" - }, - "file_urls": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "待识别文件的下载URL" } }, "required": ["file_names"] @@ -77,6 +70,7 @@ class MixCardOCR(Component): } ] + @HTTPClient.check_param @components_run_trace def run(self, message: Message, timeout: float = None, retry: int = 0) -> Message: @@ -168,22 +162,16 @@ def _check_service_error(request_id: str, data: dict): @components_run_stream_trace def tool_eval(self, file_names: Optional[list] = [], - file_urls: Optional[dict] = {}, **kwargs): """ 对指定文件进行OCR识别。 Args: - name (str): API名称。 - streaming (bool): 是否流式输出。如果为True,则逐个返回识别结果;如果为False,则一次性返回所有识别结果。 + file_names (Optional[List], optional): 要识别的文件名列表。 **kwargs: 其他参数。 Returns: - 如果streaming为False,则返回包含所有识别结果的JSON字符串。 - 如果streaming为True,则逐个返回包含识别结果的字典,每个字典包含以下字段: - type (str): 消息类型,固定为"text"。 - text (str): 识别结果的JSON字符串。 - visible_scope (str): 消息可见范围,可以是"llm"或"user"。 + ComponentOutput: 识别结果。 Raises: InvalidRequestArgumentError: 如果请求格式错误,即文件URL不存在时抛出。 @@ -194,12 +182,10 @@ def tool_eval(self, traceid = kwargs.get("_sys_traceid", "") sys_file_names = file_names - sys_file_urls = file_urls - if not sys_file_names: sys_file_names = kwargs.get("_sys_file_names", []) - if not sys_file_urls: - sys_file_urls = kwargs.get("_sys_file_urls", {}) + + sys_file_urls = kwargs.get("_sys_file_urls", {}) for file_name in sys_file_names: if utils.is_url(file_name): diff --git a/python/core/components/v2/qrcode_ocr/component.py b/python/core/components/v2/qrcode_ocr/component.py index aaf7d79a..a38c7cc6 100644 --- a/python/core/components/v2/qrcode_ocr/component.py +++ b/python/core/components/v2/qrcode_ocr/component.py @@ -66,13 +66,6 @@ class QRcodeOCR(Component): "location": { "type": "string", "description": "是否输出二维码/条形码位置信息" - }, - "file_urls": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "待识别文件的URL下载地址" } }, "required": ["file_names"] @@ -164,14 +157,13 @@ def _check_service_error(request_id: str, data: dict): ) @components_run_stream_trace - def tool_eval(self, file_names:Optional[list]=[], location: Optional[str]="false", file_urls:Optional[dict]={}, **kwargs): + def tool_eval(self, file_names:Optional[list]=[], location: Optional[str]="false", **kwargs): """ ToolEval方法,用于执行二维码识别操作。 Args: file_names (list, 可选): 待识别文件的文件名列表。 location (str, 可选): 是否需要返回二维码位置信息,默认为 "false"。 - file_urls (dict, 可选): 待识别文件的URL下载地址字典,格式为 {"filename": "url"}。 Yields: ComponentOutput: 识别结果,包含识别到的二维码信息。 @@ -180,13 +172,10 @@ def tool_eval(self, file_names:Optional[list]=[], location: Optional[str]="false traceid = kwargs.get("_sys_traceid", "") # file_name sys_file_names = file_names - sys_file_urls = file_urls - if not sys_file_names: sys_file_names = kwargs.get("_sys_file_names", []) - if not sys_file_urls: - sys_file_urls = kwargs.get("_sys_file_urls", {}) + sys_file_urls = kwargs.get("_sys_file_urls", {}) for file_name in sys_file_names: if utils.is_url(file_name): diff --git a/python/core/components/v2/tree_mind/component.py b/python/core/components/v2/tree_mind/component.py index 17b81c0b..d3f56e4c 100644 --- a/python/core/components/v2/tree_mind/component.py +++ b/python/core/components/v2/tree_mind/component.py @@ -15,6 +15,7 @@ r"""树图工具""" import json +from urllib.parse import urlparse, unquote from typing import Dict, List, Optional, Any from appbuilder.core.message import Message from appbuilder.core._client import HTTPClient @@ -83,6 +84,17 @@ def _post(self, query, **kwargs): img_link = treemind_response.info.downloadInfo.fileInfo.pic return img_link, jump_link + @staticmethod + def get_filename_from_url(url): + """从给定URL中提取文件名""" + parsed_url = urlparse(url) + # 提取路径部分 + path = parsed_url.path + # 从路径中获取文件名 + filename = path.split('/')[-1] + # 解码URL编码的文件名 + return unquote(filename) + @components_run_stream_trace def tool_eval( self, @@ -115,6 +127,7 @@ def tool_eval( img_link_result = self.create_output( type="image", text={ + "filename": self.get_filename_from_url(img_link), "url": img_link }, visible_scope='all', diff --git a/python/tests/component_check.py b/python/tests/component_check.py index 52d919a4..0351512c 100644 --- a/python/tests/component_check.py +++ b/python/tests/component_check.py @@ -1,10 +1,8 @@ -import os import json +import os import inspect -import time -from jsonschema import validate, ValidationError, SchemaError +from jsonschema import validate from pydantic import BaseModel -from typing import Generator from appbuilder.utils.func_utils import Singleton from appbuilder.tests.component_schemas import type_to_json_schemas from appbuilder.utils.json_schema_to_model import json_schema_to_pydantic_model @@ -40,12 +38,15 @@ def register_rule(self, rule_name: str, rule_obj: RuleBase): def remove_rule(self, rule_name: str): del self.rules[rule_name] - def notify(self, component_cls) -> tuple[bool, list]: + def notify(self, component_cls, component_case) -> tuple[bool, list]: check_pass = True check_details = {} reasons = [] for rule_name, rule_obj in self.rules.items(): - res = rule_obj.check(component_cls) + if rule_name == "ToolEvalOutputJsonRule": + res = rule_obj.check(component_cls, component_case) + else: + res = rule_obj.check(component_cls) check_details[rule_name] = res if res.check_result == False: check_pass = False @@ -63,53 +64,40 @@ class ManifestValidRule(RuleBase): def __init__(self, **kwargs): super().__init__() self.rule_name = "ManifestValidRule" - self.component_tool_eval_cases = kwargs.get("component_tool_eval_cases", {}) - def check(self, component_cls) -> CheckInfo: + def check(self, component_obj) -> CheckInfo: check_pass_flag = True invalid_details = [] - component_cls_name = component_cls.__name__ - if component_cls_name not in self.component_tool_eval_cases: - invalid_details.append("{} 没有添加测试case到 component_tool_eval_cases 中".format(component_cls_name)) - else: - component_case = self.component_tool_eval_cases[component_cls_name]() - envs = component_case.envs() - os.environ.update(envs) - init_args = component_case.init_args() - try: - component_obj = component_cls(**init_args) - if not hasattr(component_obj, "manifests"): - raise ValueError("No manifests found") - manifests = component_obj.manifests - # NOTE(暂时检查manifest中的第一个mainfest) - if not manifests or len(manifests) == 0: - raise ValueError("No manifests found") - manifest = manifests[0] - tool_name = manifest['name'] - tool_desc = manifest['description'] - schema = manifest["parameters"] - schema["title"] = tool_name - # 第一步,将json schema转换为pydantic模型 - pydantic_model = json_schema_to_pydantic_model(schema, tool_name) - check_to_json = pydantic_model.schema_json() - json_to_dict = json.loads(check_to_json) - - if "properties" in schema: - properties = schema["properties"] - for key, value in properties.items(): - if "type" not in value: - invalid_details.append("\'type' must be in properties item: {}".format(key)) - if "description" not in value: - invalid_details.append("\'description' must be in properties item: {}".format(key)) - - except Exception as e: - print(e) - check_pass_flag = False - invalid_details.append(str(e)) - - for env in envs: - os.environ.pop(env) + try: + if not hasattr(component_obj, "manifests"): + raise ValueError("No manifests found") + manifests = component_obj.manifests + # NOTE(暂时检查manifest中的第一个mainfest) + if not manifests or len(manifests) == 0: + raise ValueError("No manifests found") + manifest = manifests[0] + tool_name = manifest['name'] + tool_desc = manifest['description'] + schema = manifest["parameters"] + schema["title"] = tool_name + # 第一步,将json schema转换为pydantic模型 + pydantic_model = json_schema_to_pydantic_model(schema, tool_name) + check_to_json = pydantic_model.schema_json() + json_to_dict = json.loads(check_to_json) + + if "properties" in schema: + properties = schema["properties"] + for key, value in properties.items(): + if "type" not in value: + invalid_details.append("\'type' must be in properties item: {}".format(key)) + if "description" not in value: + invalid_details.append("\'description' must be in properties item: {}".format(key)) + + except Exception as e: + print(e) + check_pass_flag = False + invalid_details.append(str(e)) if len(invalid_details) > 0: check_pass_flag = False @@ -137,14 +125,14 @@ def __init__(self): self.rule_name = "MainfestMatchToolEvalRule" - def check(self, component_cls) -> CheckInfo: + def check(self, component_obj) -> CheckInfo: check_pass_flag = True invalid_details = [] try: - if not hasattr(component_cls, "manifests"): + if not hasattr(component_obj, "manifests"): raise ValueError("No manifests found") - manifests = component_cls.manifests + manifests = component_obj.manifests # NOTE(暂时检查manifest中的第一个mainfest) if not manifests or len(manifests) == 0: raise ValueError("No manifests found") @@ -158,7 +146,7 @@ def check(self, component_cls) -> CheckInfo: # 交互检查 tool_eval_input_params = [] print("required_params: {}".format(manifest_var)) - signature = inspect.signature(component_cls.tool_eval) + signature = inspect.signature(component_obj.tool_eval) ileagal_params = [] for param_name, param in signature.parameters.items(): if param_name == 'kwargs' or param_name == 'args' or param_name == 'self': @@ -193,10 +181,6 @@ def check(self, component_cls) -> CheckInfo: check_detail=",".join(invalid_details)) - - - - class ToolEvalInputNameRule(RuleBase): """ 检查tool_eval的输入参数中,是否包含系统保留的输入名称 @@ -222,10 +206,15 @@ def __init__(self): "_sys_custom_variables", "_sys_thought_model_config", "_sys_rag_model_config", + "_sys_parent_span_id", + "_sys_span_id", + "_sys_memory", + "_sys_code_execution_endpoint", + "_sys_session_id" ] - def check(self, component_cls) -> CheckInfo: - tool_eval_signature = inspect.signature(component_cls.__init__) + def check(self, component_obj) -> CheckInfo: + tool_eval_signature = inspect.signature(component_obj.tool_eval) params = tool_eval_signature.parameters invalid_details = [] check_pass_flag = True @@ -250,7 +239,6 @@ class ToolEvalOutputJsonRule(RuleBase): def __init__(self, **kwargs): super().__init__() self.rule_name = 'ToolEvalOutputJsonRule' - self.component_tool_eval_cases = kwargs.get("component_tool_eval_cases") def _check_pre_format(self, outputs): invalid_details = [] @@ -351,42 +339,26 @@ def _check_text_and_code(self, component_case, output_dict): else: return [] - def check(self, component_cls) -> CheckInfo: + def check(self, component_obj, component_case) -> CheckInfo: invalid_details = [] - component_cls_name = component_cls.__name__ - if component_cls_name not in self.component_tool_eval_cases: - invalid_details.append("{} 没有添加测试case到 component_tool_eval_cases 中".format(component_cls_name)) - else: - component_case = self.component_tool_eval_cases[component_cls_name]() - - envs = {} - if hasattr(component_case, "envs"): - envs = component_case.envs() - os.environ.update(envs) - - input_dict = component_case.inputs() - init_args = component_case.init_args() - component_obj = component_cls(**init_args) - output_json_schemas = component_case.schemas() - - try: - stream_output_dict = {"text": "", "oral_text":"", "code": ""} - stream_outputs = component_obj.tool_eval(**input_dict) - for stream_output in stream_outputs: - iter_invalid_detail = self._check_jsonschema(stream_output.model_dump(), output_json_schemas) - invalid_details.extend(iter_invalid_detail) - iter_output_dict = self._gather_iter_outputs(stream_output) - stream_output_dict["text"] += iter_output_dict["text"] - stream_output_dict["oral_text"] += iter_output_dict["oral_text"] - stream_output_dict["code"] += iter_output_dict["code"] - if len(invalid_details) == 0: - invalid_details.extend(self._check_text_and_code(component_case, stream_output_dict)) - except Exception as e: - invalid_details.append("ToolEval执行失败: {}".format(e)) - - for env in envs: - os.environ.pop(env) + input_dict = component_case.inputs() + output_json_schemas = component_case.schemas() + + try: + stream_output_dict = {"text": "", "oral_text":"", "code": ""} + stream_outputs = component_obj.tool_eval(**input_dict) + for stream_output in stream_outputs: + iter_invalid_detail = self._check_jsonschema(stream_output.model_dump(), output_json_schemas) + invalid_details.extend(iter_invalid_detail) + iter_output_dict = self._gather_iter_outputs(stream_output) + stream_output_dict["text"] += iter_output_dict["text"] + stream_output_dict["oral_text"] += iter_output_dict["oral_text"] + stream_output_dict["code"] += iter_output_dict["code"] + if len(invalid_details) == 0: + invalid_details.extend(self._check_text_and_code(component_case, stream_output_dict)) + except Exception as e: + invalid_details.append("ToolEval执行失败: {}".format(e)) if len(invalid_details) > 0: return CheckInfo( @@ -400,6 +372,91 @@ def check(self, component_cls) -> CheckInfo: check_detail="") -def register_component_check_rule(rule_name: str, rule_cls: RuleBase, init_args={}): +def register_component_check_rule(rule_name: str, rule_cls: RuleBase): component_checker = ComponentCheckBase() - component_checker.register_rule(rule_name, rule_cls(**init_args)) \ No newline at end of file + component_checker.register_rule(rule_name, rule_cls()) + + +def check_component_with_retry(component_import_res_tuple): + """ + 使用重试机制检查组件。测试用例失败后会重试两次。 + + Args: + component_import_res_tuple (tuple): 包含组件和导入结果的元组。 + + Returns: + list: 包含错误信息的数据列表。 + + """ + component, import_res, component_case_cls = component_import_res_tuple + component_check_base = ComponentCheckBase() + if inspect.isclass(component): + component_name = component.__name__ + else: + component_name = component + error_data = [] + max_retries = 2 # 设置最大重试次数 + attempts = 0 + + while attempts <= max_retries: + if import_res["import_error"] != "": + error_data.append({"Component Name": component_name, "Error Message": import_res["import_error"]}) + print("组件名称:{} 错误信息:{}".format(component_name, import_res["import_error"])) + break + + component_case = component_case_cls() + envs = component_case.envs() + os.environ.update(envs) + component_cls = import_res["obj"] + component_obj = component_cls(**component_case.init_args()) + + try: + # 此处的self.component_check_base.notify需要根据实际情况修改 + pass_check, reasons = component_check_base.notify(component_obj, component_case) # 示例修改 + reasons = list(set(reasons)) + if not pass_check: + error_data.append({"Component Name": component_name, "Error Message": ", ".join(reasons)}) + print("组件名称:{} 错误信息:{}".format(component_name, ", ".join(reasons))) + # 如果检查失败,增加尝试次数并重试 + attempts += 1 + if attempts <= max_retries: + print("组件名称:{} 将重试,当前尝试次数:{}".format(component_name, attempts)) + continue + # 如果检查通过,则退出循环 + break + except Exception as e: + error_data.append({"Component Name": component_name, "Error Message": str(e)}) + print("组件名称:{} 错误信息:{}".format(component_name, str(e))) + # 如果发生异常,增加尝试次数并重试 + attempts += 1 + if attempts <= max_retries: + print("组件名称:{} 将重试,当前尝试次数:{}".format(component_name, attempts)) + continue + + finally: + for env in envs: + os.environ.pop(env) + + return error_data + +def write_error_data(txt_file_path, error_df, error_stats): + """将组件错误信息写入文件 + + Args: + error_df (Union[pd.DataFrame, None]): 错误信息表格 + error_stats (dict): 错误统计信息 + """ + with open(txt_file_path, 'w') as file: + file.write("Component Name\tError Message\n") + for _, row in error_df.iterrows(): + file.write(f"{row['Component Name']}\t{row['Error Message']}\n") + file.write("\n错误统计信息:\n") + for error, count in error_stats.items(): + file.write(f"错误信息: {error}, 出现次数: {count}\n") + print(f"\n错误信息已写入: {txt_file_path}") + + +register_component_check_rule("ManifestValidRule", ManifestValidRule) +register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule) +register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule) +register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule) \ No newline at end of file diff --git a/python/tests/component_tool_eval_cases.py b/python/tests/component_tool_eval_cases.py index 9cea3b67..19f999cc 100644 --- a/python/tests/component_tool_eval_cases.py +++ b/python/tests/component_tool_eval_cases.py @@ -106,7 +106,7 @@ def inputs(self): "e74ab057ce26d50e966dc31ff083e6a9c41b" return { "file_names": ["text"], - "file_urls": {"text": image_url} + "_sys_file_urls": {"text": image_url} } def schemas(self): @@ -151,7 +151,7 @@ def inputs(self): "F677f93445fb65157bee11cd492ce213d5c56e7a41827e45ce7e32b083d195c8b" return { "file_names": ["text"], - "file_urls": {"text": image_url} + "_sys_file_urls": {"text": image_url} } def schemas(self): @@ -168,7 +168,7 @@ def inputs(self): "1865e4393da5a3515e90d72d81ef18296bd29598") return { "file_names": ["test"], - "file_urls": {"test": image_url} + "_sys_file_urls": {"test": image_url} } def schemas(self): diff --git a/python/tests/test_all_components.py b/python/tests/test_all_components.py index 00dc956c..728e7eb3 100644 --- a/python/tests/test_all_components.py +++ b/python/tests/test_all_components.py @@ -4,86 +4,12 @@ import pandas as pd import unittest import os -from appbuilder.tests.component_check import ComponentCheckBase -from appbuilder.tests.component_check import register_component_check_rule -from appbuilder.tests.component_check import ManifestValidRule, MainfestMatchToolEvalRule, ToolEvalInputNameRule, ToolEvalOutputJsonRule + from appbuilder.core._exception import AppbuilderBuildexException -from component_collector import get_all_components, get_v2_components, get_component_white_list from component_tool_eval_cases import component_tool_eval_cases +from component_collector import get_all_components, get_v2_components, get_component_white_list +from component_check import check_component_with_retry, write_error_data -register_component_check_rule("ManifestValidRule", ManifestValidRule, \ - {"component_tool_eval_cases": component_tool_eval_cases}) -register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule, {}) -register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule, {}) -register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule, \ - {"component_tool_eval_cases": component_tool_eval_cases}) - - -def check_component_with_retry(component_import_res_tuple): - """ - 使用重试机制检查组件。测试用例失败后会重试两次。 - - Args: - component_import_res_tuple (tuple): 包含组件和导入结果的元组。 - - Returns: - list: 包含错误信息的数据列表。 - - """ - component, import_res = component_import_res_tuple - component_check_base = ComponentCheckBase() - - error_data = [] - max_retries = 2 # 设置最大重试次数 - attempts = 0 - - while attempts <= max_retries: - if import_res["import_error"] != "": - error_data.append({"Component Name": component, "Error Message": import_res["import_error"]}) - print("组件名称:{} 错误信息:{}".format(component, import_res["import_error"])) - break - - component_obj = import_res["obj"] - try: - # 此处的self.component_check_base.notify需要根据实际情况修改 - pass_check, reasons = component_check_base.notify(component_obj) # 示例修改 - reasons = list(set(reasons)) - if not pass_check: - error_data.append({"Component Name": component, "Error Message": ", ".join(reasons)}) - print("组件名称:{} 错误信息:{}".format(component, ", ".join(reasons))) - # 如果检查失败,增加尝试次数并重试 - attempts += 1 - if attempts <= max_retries: - print("组件名称:{} 将重试,当前尝试次数:{}".format(component, attempts)) - continue - # 如果检查通过,则退出循环 - break - except Exception as e: - error_data.append({"Component Name": component, "Error Message": str(e)}) - print("组件名称:{} 错误信息:{}".format(component, str(e))) - # 如果发生异常,增加尝试次数并重试 - attempts += 1 - if attempts <= max_retries: - print("组件名称:{} 将重试,当前尝试次数:{}".format(component, attempts)) - continue - - return error_data - -def write_error_data(txt_file_path, error_df, error_stats): - """将组件错误信息写入文件 - - Args: - error_df (Union[pd.DataFrame, None]): 错误信息表格 - error_stats (dict): 错误统计信息 - """ - with open(txt_file_path, 'w') as file: - file.write("Component Name\tError Message\n") - for _, row in error_df.iterrows(): - file.write(f"{row['Component Name']}\t{row['Error Message']}\n") - file.write("\n错误统计信息:\n") - for error, count in error_stats.items(): - file.write(f"错误信息: {error}, 出现次数: {count}\n") - print(f"\n错误信息已写入: {txt_file_path}") @unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") class TestComponentManifestsAndToolEval(unittest.TestCase): @@ -110,7 +36,7 @@ def setUp(self) -> None: self.v2_components = get_v2_components() self.whitelist_components = get_component_white_list() - def _test_component(self, components, whitelist_components, txt_file_path): + def _test_component(self, components, component_cases, whitelist_components, txt_file_path): """测试所有组件的manifests和tool_eval入参 Args: 无 @@ -122,7 +48,16 @@ def _test_component(self, components, whitelist_components, txt_file_path): with Pool(processes=os.cpu_count()) as pool: # 使用pool.map来执行多进程 - results = pool.map(check_component_with_retry, components.items()) + args = [] + for component, import_res in components.items(): + if component not in component_cases: + error_data.append({"Component Name": component, "Error Message": "{} 没有添加测试case到 \ + component_tool_eval_cases 中".format(component)}) + continue + else: + args.append((component, import_res, component_tool_eval_cases[component])) + + results = pool.map(check_component_with_retry, args) # 合并每个进程返回的错误数据 for result in results: @@ -154,13 +89,13 @@ def _test_component(self, components, whitelist_components, txt_file_path): else: print("\n所有组件测试通过,无错误信息。") - def test_all_components(self): + def _test_all_components(self): """测试旧版本组件""" - self._test_component(self.all_components, self.whitelist_components, 'components_error_info.txt') + self._test_component(self.all_components, [], self.whitelist_components, 'components_error_info.txt') def test_v2_components(self): """测试v2版本组件""" - self._test_component(self.v2_components, [], 'v2_components_error_info.txt') + self._test_component(self.v2_components, component_tool_eval_cases, [], 'v2_components_error_info.txt') if __name__ == '__main__': diff --git a/python/tests/test_base_component.py b/python/tests/test_base_component.py index 94d5ff4e..a560cf78 100644 --- a/python/tests/test_base_component.py +++ b/python/tests/test_base_component.py @@ -32,6 +32,7 @@ def test_valid_output_with_dict(self): output8 = self.component.create_output(type="audio", text={"filename": "file.mp3", "url": "http://www.baidu.com"}) output9 = self.component.create_output(type="plan", text={"detail": "hello", "steps":[{"name": "1", "arguments": {"query": "a", "chat_history": "world"}}]}) output10 = self.component.create_output(type="function_call", text={"thought": "hello", "name": "AppBuilder", "arguments": {"query": "a", "chat_history": "world"}}) + output11 = self.component.create_output(type="references", text={"type": "engine", "doc_id": "1", "content": "hello, world", "title": "Have a nice day", "source": "bing", "extra": {"key": "value"}}) self.assertIsInstance(output1, ComponentOutput) self.assertIsInstance(output2, ComponentOutput) self.assertIsInstance(output3, ComponentOutput) @@ -42,6 +43,8 @@ def test_valid_output_with_dict(self): self.assertIsInstance(output8, ComponentOutput) self.assertIsInstance(output9, ComponentOutput) self.assertIsInstance(output10, ComponentOutput) + self.assertIsInstance(output11, ComponentOutput) + self.assertEqual(output11.content[0].text.extra["key"], "value") def test_valid_output_type_with_same_key(self): output1 = self.component.create_output(type="urls", text={"url": "http://www.baidu.com"}) diff --git a/python/tests/test_v2_handwrite_ocr.py b/python/tests/test_v2_handwrite_ocr.py index dfc8a839..059481f4 100644 --- a/python/tests/test_v2_handwrite_ocr.py +++ b/python/tests/test_v2_handwrite_ocr.py @@ -80,7 +80,7 @@ def test_tool_eval(self): next(result) result=self.handwrite_ocr.tool_eval( file_names=['test'], - file_urls={'test':self.image_url} + _sys_file_urls={'test':self.image_url} ) res=next(result) print(res) diff --git a/python/tests/test_v2_mix_card_ocr.py b/python/tests/test_v2_mix_card_ocr.py index 36101413..1ba67310 100644 --- a/python/tests/test_v2_mix_card_ocr.py +++ b/python/tests/test_v2_mix_card_ocr.py @@ -82,10 +82,8 @@ def test_tool_eval(self): with self.assertRaises(InvalidRequestArgumentError): next(result) result=self.mix_card_ocr.tool_eval( - name='name', - streaming=True, file_names=['test'], - file_urls={'test':self.image_url} + _sys_file_urls={'test':self.image_url} ) res=next(result) print("res: {}".format(res)) diff --git a/python/tests/test_v2_qrcode_ocr.py b/python/tests/test_v2_qrcode_ocr.py index ae8198d9..d1a4a845 100644 --- a/python/tests/test_v2_qrcode_ocr.py +++ b/python/tests/test_v2_qrcode_ocr.py @@ -145,7 +145,7 @@ def test_tool_eval(self): print(msg) result=self.qrcode_ocr.tool_eval( file_names=['test'], - file_urls={'test':image_url} + _sys_file_urls={'test':image_url} ) res=next(result) print(res)