Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

恢复部分Components组件的私有函数,更新组件chart数据类型key_list #662

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra
elif type == "image":
key_list = ["filename", "url"]
elif type == "chart":
key_list = ["filename", "url"]
key_list = ["type", "data"]
elif type == "audio":
key_list = ["filename", "url"]
elif type == "plan":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
OralQueryGenerationArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)

def _regenerate_output(self, model_output, output_format):
def regenerate_output(self, model_output, output_format):
"""
兼容老版本的输出格式
"""
Expand Down Expand Up @@ -130,7 +130,7 @@ def _regenerate_output(self, model_output, output_format):
regenerated_output = '\n'.join([f'{index}. {query}' for index, query in enumerate(queries, 1)])
return regenerated_output

def _completion(self, version, base_url, request, timeout: float = None,
def completion(self, version, base_url, request, timeout: float = None,
retry: int = 0):
r"""Send a byte array of an audio file to obtain the result of speech recognition."""

Expand Down Expand Up @@ -187,13 +187,13 @@ def run(self, message, query_type='全部', output_format='str', stream=False, t
model_config = self.get_model_config(model_config_inputs)

request = self.gene_request('', inputs, response_mode, user_id, model_config)
response = self._completion(self.version, self.base_url, request)
response = self.completion(self.version, self.base_url, request)

if response.error_no != 0:
raise AppBuilderServerException(service_err_code=response.error_no, service_err_message=response.error_msg)

result = response.to_message()
result.content = self._regenerate_output(result.content, output_format)
result.content = self.regenerate_output(result.content, output_format)

return result

Expand Down
4 changes: 2 additions & 2 deletions python/core/components/v2/table_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _check_service_error(request_id: str, data: dict):
service_err_message=data.get("error_msg")
)

def _get_table_markdown(self, tables_result):
def get_table_markdown(self, tables_result):
"""
将表格识别结果转换为Markdown格式。

Expand Down Expand Up @@ -222,7 +222,7 @@ def tool_eval(self,
req.cell_contents = "false"
resp, raw_data = self._recognize(req, request_id=traceid)
tables_result = proto.Message.to_dict(resp)["tables_result"]
markdowns = self._get_table_markdown(tables_result)
markdowns = self.get_table_markdown(tables_result)
result[file_name] = markdowns

result = json.dumps(result, ensure_ascii=False)
Expand Down
12 changes: 6 additions & 6 deletions python/core/components/v2/text_to_image/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _recognize(

while True:
request = Text2ImageQueryRequest(task_id=taskId)
text2ImageQueryResponse, data = self._queryText2ImageData(request, request_id=request_id)
text2ImageQueryResponse, data = self.queryText2ImageData(request, request_id=request_id)
if text2ImageQueryResponse.data.task_progress is not None:
task_progress = float(text2ImageQueryResponse.data.task_progress)
if math.isclose(1.0, task_progress, rel_tol=1e-9, abs_tol=0.0):
Expand All @@ -285,11 +285,11 @@ def _recognize(
time.sleep(task_request_time)
task_request_time += 1

img_urls = self._extract_img_urls(text2ImageQueryResponse)
img_urls = self.extract_img_urls(text2ImageQueryResponse)

return img_urls, data

def _queryText2ImageData(
def queryText2ImageData(
self,
request: Text2ImageQueryRequest,
timeout: float = None,
Expand Down Expand Up @@ -321,11 +321,11 @@ def _queryText2ImageData(
data = response.json()
self.http_client.check_response_json(data)
request_id = self.http_client.response_request_id(response)
self.__class__._check_service_error(request_id, data)
self.__class__.check_service_error(request_id, data)
response = Text2ImageQueryResponse(**data)
return response, data

def _extract_img_urls(self, response: Text2ImageQueryResponse):
def extract_img_urls(self, response: Text2ImageQueryResponse):
"""
从作画生成的返回结果中提取图片url。

Expand All @@ -347,7 +347,7 @@ def _extract_img_urls(self, response: Text2ImageQueryResponse):
return img_urls

@staticmethod
def _check_service_error(request_id: str, data: dict):
def check_service_error(request_id: str, data: dict):
"""
检查服务错误信息

Expand Down
Loading