From 7129e1dec38a19c52225e1f998303cd0e6ebf5f5 Mon Sep 17 00:00:00 2001 From: yinjiaqi Date: Fri, 13 Dec 2024 13:14:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=81=A2=E5=A4=8D=E9=83=A8=E5=88=86Components?= =?UTF-8?q?=E7=BB=84=E4=BB=B6=E7=9A=84=E7=A7=81=E6=9C=89=E5=87=BD=E6=95=B0?= =?UTF-8?q?,=E6=9B=B4=E6=96=B0=E7=BB=84=E4=BB=B6chart=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E7=B1=BB=E5=9E=8Bkey=5Flist?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/core/component.py | 2 +- .../v2/llms/oral_query_generation/component.py | 8 ++++---- python/core/components/v2/table_ocr/component.py | 4 ++-- python/core/components/v2/text_to_image/component.py | 12 ++++++------ 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/core/component.py b/python/core/component.py index 4254bbef..3b892c61 100644 --- a/python/core/component.py +++ b/python/core/component.py @@ -537,7 +537,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra elif type == "image": key_list = ["filename", "url"] elif type == "chart": - key_list = ["filename", "url"] + key_list = ["type", "data"] elif type == "audio": key_list = ["filename", "url"] elif type == "plan": diff --git a/python/core/components/v2/llms/oral_query_generation/component.py b/python/core/components/v2/llms/oral_query_generation/component.py index 67359d98..55ca08f5 100644 --- a/python/core/components/v2/llms/oral_query_generation/component.py +++ b/python/core/components/v2/llms/oral_query_generation/component.py @@ -101,7 +101,7 @@ def __init__( OralQueryGenerationArgs, model=model, secret_key=secret_key, gateway=gateway, lazy_certification=lazy_certification) - def _regenerate_output(self, model_output, output_format): + def regenerate_output(self, model_output, output_format): """ 兼容老版本的输出格式 """ @@ -130,7 +130,7 @@ def _regenerate_output(self, model_output, output_format): regenerated_output = '\n'.join([f'{index}. {query}' for index, query in enumerate(queries, 1)]) return regenerated_output - def _completion(self, version, base_url, request, timeout: float = None, + def completion(self, version, base_url, request, timeout: float = None, retry: int = 0): r"""Send a byte array of an audio file to obtain the result of speech recognition.""" @@ -187,13 +187,13 @@ def run(self, message, query_type='全部', output_format='str', stream=False, t model_config = self.get_model_config(model_config_inputs) request = self.gene_request('', inputs, response_mode, user_id, model_config) - response = self._completion(self.version, self.base_url, request) + response = self.completion(self.version, self.base_url, request) if response.error_no != 0: raise AppBuilderServerException(service_err_code=response.error_no, service_err_message=response.error_msg) result = response.to_message() - result.content = self._regenerate_output(result.content, output_format) + result.content = self.regenerate_output(result.content, output_format) return result diff --git a/python/core/components/v2/table_ocr/component.py b/python/core/components/v2/table_ocr/component.py index 114c256e..915cdf92 100644 --- a/python/core/components/v2/table_ocr/component.py +++ b/python/core/components/v2/table_ocr/component.py @@ -152,7 +152,7 @@ def _check_service_error(request_id: str, data: dict): service_err_message=data.get("error_msg") ) - def _get_table_markdown(self, tables_result): + def get_table_markdown(self, tables_result): """ 将表格识别结果转换为Markdown格式。 @@ -222,7 +222,7 @@ def tool_eval(self, req.cell_contents = "false" resp, raw_data = self._recognize(req, request_id=traceid) tables_result = proto.Message.to_dict(resp)["tables_result"] - markdowns = self._get_table_markdown(tables_result) + markdowns = self.get_table_markdown(tables_result) result[file_name] = markdowns result = json.dumps(result, ensure_ascii=False) diff --git a/python/core/components/v2/text_to_image/component.py b/python/core/components/v2/text_to_image/component.py index 6da931ff..594cbf6b 100644 --- a/python/core/components/v2/text_to_image/component.py +++ b/python/core/components/v2/text_to_image/component.py @@ -271,7 +271,7 @@ def _recognize( while True: request = Text2ImageQueryRequest(task_id=taskId) - text2ImageQueryResponse, data = self._queryText2ImageData(request, request_id=request_id) + text2ImageQueryResponse, data = self.queryText2ImageData(request, request_id=request_id) if text2ImageQueryResponse.data.task_progress is not None: task_progress = float(text2ImageQueryResponse.data.task_progress) if math.isclose(1.0, task_progress, rel_tol=1e-9, abs_tol=0.0): @@ -285,11 +285,11 @@ def _recognize( time.sleep(task_request_time) task_request_time += 1 - img_urls = self._extract_img_urls(text2ImageQueryResponse) + img_urls = self.extract_img_urls(text2ImageQueryResponse) return img_urls, data - def _queryText2ImageData( + def queryText2ImageData( self, request: Text2ImageQueryRequest, timeout: float = None, @@ -321,11 +321,11 @@ def _queryText2ImageData( data = response.json() self.http_client.check_response_json(data) request_id = self.http_client.response_request_id(response) - self.__class__._check_service_error(request_id, data) + self.__class__.check_service_error(request_id, data) response = Text2ImageQueryResponse(**data) return response, data - def _extract_img_urls(self, response: Text2ImageQueryResponse): + def extract_img_urls(self, response: Text2ImageQueryResponse): """ 从作画生成的返回结果中提取图片url。 @@ -347,7 +347,7 @@ def _extract_img_urls(self, response: Text2ImageQueryResponse): return img_urls @staticmethod - def _check_service_error(request_id: str, data: dict): + def check_service_error(request_id: str, data: dict): """ 检查服务错误信息