Skip to content

Commit

Permalink
恢复部分Components组件的私有函数,更新组件chart数据类型key_list
Browse files Browse the repository at this point in the history
  • Loading branch information
yinjiaqi authored and yinjiaqi committed Dec 13, 2024
1 parent 0bf0ed7 commit 7129e1d
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra
elif type == "image":
key_list = ["filename", "url"]
elif type == "chart":
key_list = ["filename", "url"]
key_list = ["type", "data"]
elif type == "audio":
key_list = ["filename", "url"]
elif type == "plan":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
OralQueryGenerationArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)

def _regenerate_output(self, model_output, output_format):
def regenerate_output(self, model_output, output_format):
"""
兼容老版本的输出格式
"""
Expand Down Expand Up @@ -130,7 +130,7 @@ def _regenerate_output(self, model_output, output_format):
regenerated_output = '\n'.join([f'{index}. {query}' for index, query in enumerate(queries, 1)])
return regenerated_output

def _completion(self, version, base_url, request, timeout: float = None,
def completion(self, version, base_url, request, timeout: float = None,
retry: int = 0):
r"""Send a byte array of an audio file to obtain the result of speech recognition."""

Expand Down Expand Up @@ -187,13 +187,13 @@ def run(self, message, query_type='全部', output_format='str', stream=False, t
model_config = self.get_model_config(model_config_inputs)

request = self.gene_request('', inputs, response_mode, user_id, model_config)
response = self._completion(self.version, self.base_url, request)
response = self.completion(self.version, self.base_url, request)

if response.error_no != 0:
raise AppBuilderServerException(service_err_code=response.error_no, service_err_message=response.error_msg)

result = response.to_message()
result.content = self._regenerate_output(result.content, output_format)
result.content = self.regenerate_output(result.content, output_format)

return result

Expand Down
4 changes: 2 additions & 2 deletions python/core/components/v2/table_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _check_service_error(request_id: str, data: dict):
service_err_message=data.get("error_msg")
)

def _get_table_markdown(self, tables_result):
def get_table_markdown(self, tables_result):
"""
将表格识别结果转换为Markdown格式。
Expand Down Expand Up @@ -222,7 +222,7 @@ def tool_eval(self,
req.cell_contents = "false"
resp, raw_data = self._recognize(req, request_id=traceid)
tables_result = proto.Message.to_dict(resp)["tables_result"]
markdowns = self._get_table_markdown(tables_result)
markdowns = self.get_table_markdown(tables_result)
result[file_name] = markdowns

result = json.dumps(result, ensure_ascii=False)
Expand Down
12 changes: 6 additions & 6 deletions python/core/components/v2/text_to_image/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _recognize(

while True:
request = Text2ImageQueryRequest(task_id=taskId)
text2ImageQueryResponse, data = self._queryText2ImageData(request, request_id=request_id)
text2ImageQueryResponse, data = self.queryText2ImageData(request, request_id=request_id)
if text2ImageQueryResponse.data.task_progress is not None:
task_progress = float(text2ImageQueryResponse.data.task_progress)
if math.isclose(1.0, task_progress, rel_tol=1e-9, abs_tol=0.0):
Expand All @@ -285,11 +285,11 @@ def _recognize(
time.sleep(task_request_time)
task_request_time += 1

img_urls = self._extract_img_urls(text2ImageQueryResponse)
img_urls = self.extract_img_urls(text2ImageQueryResponse)

return img_urls, data

def _queryText2ImageData(
def queryText2ImageData(
self,
request: Text2ImageQueryRequest,
timeout: float = None,
Expand Down Expand Up @@ -321,11 +321,11 @@ def _queryText2ImageData(
data = response.json()
self.http_client.check_response_json(data)
request_id = self.http_client.response_request_id(response)
self.__class__._check_service_error(request_id, data)
self.__class__.check_service_error(request_id, data)
response = Text2ImageQueryResponse(**data)
return response, data

def _extract_img_urls(self, response: Text2ImageQueryResponse):
def extract_img_urls(self, response: Text2ImageQueryResponse):
"""
从作画生成的返回结果中提取图片url。
Expand All @@ -347,7 +347,7 @@ def _extract_img_urls(self, response: Text2ImageQueryResponse):
return img_urls

@staticmethod
def _check_service_error(request_id: str, data: dict):
def check_service_error(request_id: str, data: dict):
"""
检查服务错误信息
Expand Down

0 comments on commit 7129e1d

Please sign in to comment.