Skip to content

Commit

Permalink
Merge pull request #23 from ide-rea/master
Browse files Browse the repository at this point in the history
Decouple HTTPClient from Component
  • Loading branch information
guru4elephant authored Dec 27, 2023
2 parents ee80dbd + aec7ae8 commit c41b7d1
Show file tree
Hide file tree
Showing 19 changed files with 164 additions and 224 deletions.
3 changes: 2 additions & 1 deletion appbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def check_version(self):
from .core.components.landmark_recognize.component import LandmarkRecognition
from .core.components.tts.component import TTS
from .core.components.extract_table.component import ExtractTableFromDoc
from .core.components.doc_parser.doc_parser import DocParser
from .core.components.doc_parser.doc_parser import DocParser, ParserConfig
from .core.components.doc_splitter.doc_splitter import DocSplitter
from .core.components.retriever.bes_retriever import BESRetriever
from .core.components.retriever.bes_retriever import BESVectorStoreIndex
Expand Down Expand Up @@ -114,6 +114,7 @@ def check_version(self):
'TTS',
"ExtractTableFromDoc",
"DocParser",
"ParserConfig",
"DocSplitter",
"BESRetriever",
"BESVectorStoreIndex",
Expand Down
57 changes: 40 additions & 17 deletions appbuilder/core/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,45 @@
"""Base client for interact with backend server"""

import os
from typing import Optional

import requests
from requests.adapters import HTTPAdapter, Retry

from appbuilder.core._exception import *
from appbuilder.core.constants import GATEWAY_URL


class BaseClient:
r"""
BaseClient class provide common method for interact with backend server.
"""
class HTTPClient:
r"""HTTPClient类,实现与后端服务交互的公共方法"""

def __init__(self, secret_key: str = "", gateway: str = ""):
r"""__init__ method.
:param secret_key: authorization token, if not set get from env variable.
:param gateway: backend server host.
:rtype:
"""
def __init__(self,
secret_key: Optional[str] = None,
gateway: str = ""
):
r"""HTTPClient初始化方法.
参数:
secret_key(str,可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway(str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
返回:
"""
self.secret_key = secret_key if secret_key else os.getenv("APPBUILDER_TOKEN", "")

if not self.secret_key:
raise ValueError("secret_key is empty, please pass a nonempty secret_key "
"or set a secret_key in environment variable")

self.gateway = gateway if gateway else os.getenv("GATEWAY_URL", "")
if not gateway and not os.getenv("GATEWAY_URL"):
self.gateway = GATEWAY_URL
else:
self.gateway = gateway if gateway else os.getenv("GATEWAY_URL", "")

# self.gateway = gateway or os.getenv("GATEWAY_URL", "https://api.xbuilder.baidu.com")
if not self.gateway.startswith("http"):
self.gateway = "https://" + gateway
self.gateway = "https://" + self.gateway
self.session = requests.sessions.Session()
self.retry = Retry(total=0, backoff_factor=0.1)
self.session.mount(self.gateway, HTTPAdapter(max_retries=self.retry))

@staticmethod
def check_response_header(response: requests.Response):
Expand All @@ -55,8 +65,8 @@ def check_response_header(response: requests.Response):
status_code = response.status_code
if status_code == requests.codes.ok:
return
request_id = response.headers.get("X-App-Engine-Request-Id", "")
message = "request_id={} , http status code is {}".format(request_id, status_code)
message = "request_id={} , http status code is {}, body is {}".format(
__class__.response_request_id(response), status_code, response.text)
if status_code == requests.codes.bad_request:
raise BadRequestException(message)
elif status_code == requests.codes.forbidden:
Expand All @@ -70,7 +80,7 @@ def check_response_header(response: requests.Response):
else:
raise BaseRPCException(message)

def service_url(self, sub_path: str, prefix=None):
def service_url(self, sub_path: str, prefix: str = None):
r"""service_url is a helper method for concatenate service url.
:param sub_path: service unique sub path.
:param prefix: service prefix.
Expand All @@ -88,3 +98,16 @@ def check_response_json(data: dict):
"""
if "code" in data and "message" in data and "requestId" in data:
raise AppBuilderServerException(data["requestId"], data["code"], data["message"])

def auth_header(self):
    r"""Build the authorization header carrying this client's token.

    ``self.secret_key`` is sent unchanged when it already starts with
    ``"Bearer "``; otherwise that prefix is prepended to it.

    :rtype: dict, holding the single key ``X-Appbuilder-Authorization``.
    """
    token = self.secret_key
    if not token.startswith("Bearer "):
        token = "Bearer {}".format(token)
    return {"X-Appbuilder-Authorization": token}

@staticmethod
def response_request_id(response: requests.Response):
r"""response_request_id is a helper method get unique request id"""
return response.headers.get("X-Appbuilder-Request-Id", "")
83 changes: 4 additions & 79 deletions appbuilder/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Component module include a Component class which is the
base class for customized Component class, define interface method like run() batch() etc.
subclass may choose to implement, also provide some simple helper method for interact with backend server."""
"""Component模块包括组件基类,用户自定义组件需要继承Component类,并至少实现run方法"""

import os
from enum import Enum

import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter, Retry
from typing import Dict, List, Optional, Any

from appbuilder.core._exception import *
from appbuilder.core._client import HTTPClient
from appbuilder.core.message import Message
from appbuilder.core.constants import GATEWAY_URL


class ComponentArguments(BaseModel):
Expand Down Expand Up @@ -72,22 +66,9 @@ def __init__(self,
返回:
"""

self.meta = meta
self.secret_key = secret_key if secret_key else os.getenv("APPBUILDER_TOKEN", "")
if not self.secret_key:
raise ValueError("secret_key is empty, please pass a nonempty secret_key "
"or set a secret_key in environment variable")

if not gateway and not os.getenv("GATEWAY_URL"):
self.gateway = GATEWAY_URL
else:
self.gateway = gateway if gateway else os.getenv("GATEWAY_URL", "")

if not self.gateway.startswith("http"):
self.gateway = "https://" + self.gateway
self.s = requests.sessions.Session()
self.retry = Retry(total=0, backoff_factor=0.1)
self.s.mount(self.gateway, HTTPAdapter(max_retries=self.retry))
self.http_client = HTTPClient(secret_key, gateway)

def __call__(self, *inputs, **kwargs):
r"""implement __call__ method"""
Expand Down Expand Up @@ -123,59 +104,3 @@ def _trace(self, **data) -> None:
def _debug(self, **data) -> None:
r"""pass"""
pass

@staticmethod
def check_response_header(response: requests.Response):
    r"""Check the HTTP status of a response and raise on anything but 200.

    The raised exception's message carries the request id, the status code
    and the response body text.

    :param response: requests.Response.
    :raises BadRequestException: status 400.
    :raises ForbiddenException: status 403.
    :raises NotFoundException: status 404.
    :raises PreconditionFailedException: status 412 or 428.
    :raises InternalServerErrorException: status 500.
    :raises BaseRPCException: any other non-200 status.
    :rtype: None.
    """
    status_code = response.status_code
    if status_code == requests.codes.ok:
        return
    message = "request_id={} , http status code is {}, body is {}".format(
        __class__.response_request_id(response), status_code, response.text)
    if status_code == requests.codes.bad_request:
        raise BadRequestException(message)
    elif status_code == requests.codes.forbidden:
        raise ForbiddenException(message)
    elif status_code == requests.codes.not_found:
        raise NotFoundException(message)
    # 412 is the actual "Precondition Failed" status. Only 428
    # ("Precondition Required") was matched here before, so a real 412 fell
    # through to the generic BaseRPCException; 428 is kept for compatibility.
    # NOTE(review): assumes PreconditionFailedException derives from
    # BaseRPCException, so existing handlers still catch a 412 -- confirm.
    elif status_code in (requests.codes.precondition_failed,
                         requests.codes.precondition_required):
        raise PreconditionFailedException(message)
    elif status_code == requests.codes.internal_server_error:
        raise InternalServerErrorException(message)
    else:
        raise BaseRPCException(message)

def service_url(self, sub_path: str, prefix: str = None):
    r"""Compose the full URL of a backend service.

    :param sub_path: service unique sub path, appended last.
    :param prefix: service prefix; "/rpc/2.0/cloud_hub" is used when it is
        None or empty.
    :rtype: str.
    """
    # gateway host + prefix (the fixed default when none is given) + sub path
    base = self.gateway + (prefix or "/rpc/2.0/cloud_hub")
    return base + sub_path

@staticmethod
def check_response_json(data: dict):
r"""check_response_json is a helper method for check backend server response.
:param: dict, body response data.
:rtype: str.
"""
if "code" in data and "message" in data and "requestId" in data:
raise AppBuilderServerException(data["requestId"], data["code"], data["message"])

def auth_header(self):
    r"""Build the authorization header carrying this object's token.

    ``self.secret_key`` is sent unchanged when it already starts with
    ``"Bearer "``; otherwise that prefix is prepended to it.

    :rtype: dict, holding the single key ``X-Appbuilder-Authorization``.
    """
    token = self.secret_key
    if not token.startswith("Bearer "):
        token = "Bearer {}".format(token)
    return {"X-Appbuilder-Authorization": token}

@staticmethod
def response_request_id(response: requests.Response):
r"""response_request_id is a helper method get unique request id"""
return response.headers.get("X-Appbuilder-Request-Id", "")
14 changes: 7 additions & 7 deletions appbuilder/core/components/asr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,20 @@ def _recognize(self, request: ShortSpeechRecognitionRequest, timeout: float = No
obj:`ShortSpeechRecognitionResponse`: 接口返回的输出消息。
"""
ContentType = "audio/" + request.format + ";rate=" + str(request.rate)
headers = self.auth_header()
headers = self.http_client.auth_header()
headers['content-type'] = ContentType
params = {
'dev_pid': request.dev_pid,
'cuid': request.cuid
}
if retry != self.retry.total:
self.retry.total = retry
response = self.s.post(self.service_url("/v1/bce/aip_speech/asrpro"), params=params, headers=headers, data=request.speech, timeout=timeout)
super().check_response_header(response)
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry
response = self.http_client.session.post(self.http_client.service_url("/v1/bce/aip_speech/asrpro"), params=params, headers=headers, data=request.speech, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
super().check_response_json(data)
self.http_client.check_response_json(data)
self.__class__._check_service_error(data)
request_id = self.response_request_id(response)
request_id = self.http_client.response_request_id(response)
response = ShortSpeechRecognitionResponse.from_json(payload=json.dumps(data))
response.request_id = request_id
return response
Expand Down
16 changes: 8 additions & 8 deletions appbuilder/core/components/dish_recognize/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ def _recognize(self, request: DishRecognitionRequest, timeout: float = None,
if not request.filter_threshold:
request.filter_threshold = 0.95
request_data = DishRecognitionRequest.to_dict(request)
if retry != self.retry.total:
self.retry.total = retry
headers = self.auth_header()
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry
headers = self.http_client.auth_header()
headers['content-type'] = 'application/x-www-form-urlencoded'

url = self.service_url("/v1/bce/aip/image-classify/v2/dish")
response = self.s.post(url, headers=headers, data=request_data, timeout=timeout)
url = self.http_client.service_url("/v1/bce/aip/image-classify/v2/dish")
response = self.http_client.session.post(url, headers=headers, data=request_data, timeout=timeout)

self.check_response_header(response)
self.http_client.check_response_header(response)
data = response.json()
self.check_response_json(data)
self.http_client.check_response_json(data)
if "error_code" in data and "error_msg" in data:
raise AppBuilderServerException(service_err_code=data["error_code"], service_err_message=data["error_msg"])
return DishRecognitionResponse(data, request_id=self.response_request_id(response))
return DishRecognitionResponse(data, request_id=self.http_client.response_request_id(response))
12 changes: 5 additions & 7 deletions appbuilder/core/components/doc_parser/doc_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,11 @@ def run(self, input_message: Message, return_raw=False) -> Message:
param["data"] = base64.b64encode(f.read()).decode()
param["name"] = os.path.basename(file_path)
payload = json.dumps({"file_list": [param]})
headers = {
"Authorization": self.secret_key,
"Content-Type": "application/json"
}
response = self.s.post(url=self.service_url(self.base_url), headers=headers, data=payload)
self.check_response_header(response)
self.check_response_json(response.json())
headers = self.http_client.auth_header()
headers["Content-Type"] = "application/json"
response = self.http_client.session.post(url=self.http_client.service_url(self.base_url), headers=headers, data=payload)
self.http_client.check_response_header(response)
self.http_client.check_response_json(response.json())
response = response.json()
if response["error_code"] != 0:
logger.error("doc parser service log_id {} err {}".format(response["log_id"], response["error_msg"]))
Expand Down
15 changes: 7 additions & 8 deletions appbuilder/core/components/doc_splitter/doc_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,17 @@ def run(self, message: Message):
if not isinstance(paser_res, ParseResult):
raise ValueError("message.content type must be a ParseResult")

headers = {
"Authorization": self.secret_key if self.secret_key else os.getenv("APPBUILDER_TOKEN"),
"Content-Type": "application/json"
}
headers = self.http_client.auth_header()
headers["Content-Type"] = "application/json"

chunk_splitter_remote_params = {"xmind_res": paser_res.raw, "max_segment_length": self.max_segment_length,
"overlap": self.overlap, "separators": self.separators,
"join_symbol": self.join_symbol}

response = self.s.post(url=self.service_url(prefix=self.base_url, sub_path=""),
headers=headers, json=chunk_splitter_remote_params, stream=False)
self.check_response_header(response)
self.check_response_json(response.json())
response = self.http_client.session.post(url=self.http_client.service_url(prefix=self.base_url, sub_path=""),
headers=headers, json=chunk_splitter_remote_params, stream=False)
self.http_client.check_response_header(response)
self.http_client.check_response_json(response.json())
doc_chunk_splitter_res = response.json()

return Message(doc_chunk_splitter_res["result"])
Expand Down
16 changes: 7 additions & 9 deletions appbuilder/core/components/embeddings/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _check_response_json(self, data: dict):
check_response_json for embedding
"""

self.check_response_json(data)
self.http_client.check_response_json(data)
if "error_code" in data and "error_msg" in data:
raise AppBuilderServerException(
service_err_code=data['error_code'],
Expand All @@ -82,16 +82,14 @@ def _request(self, payload: dict) -> dict:
"""
request to gateway
"""

resp = self.s.post(
url=self.service_url(self.base_url),
headers={
"X-Appbuilder-Authorization": f"{self.secret_key}",
"Content-Type": "application/json",
},
headers = self.http_client.auth_header()
headers["Content-Type"] = "application/json"
resp = self.http_client.session.post(
url=self.http_client.service_url(self.base_url),
headers=headers,
json=payload,
)
self.check_response_header(resp)
self.http_client.check_response_header(resp)
self._check_response_json(resp.json())

return resp.json()
Expand Down
8 changes: 4 additions & 4 deletions appbuilder/core/components/extract_table/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ def run(self, message: Message, table_max_size: int = 800, doc_node_num_before_t
"single_table_size": self.table_max_size,
"field_before_table_cnt": doc_node_num_before_table
}
url = self.service_url(sub_path="", prefix=self.base_url)
url = self.http_client.service_url(sub_path="", prefix=self.base_url)
# logger.info("request url: {}, headers: {}".format(url, headers))
resp = self.s.post(url=url, data=json.dumps(params), headers=self.auth_header())
resp = self.http_client.session.post(url=url, data=json.dumps(params), headers=self.http_client.auth_header())

self.check_response_header(resp)
self.http_client.check_response_header(resp)
resp = resp.json()
self.check_response_json(resp)
self.http_client.check_response_json(resp)
resp = self._post_process(resp)
return Message(resp)
16 changes: 8 additions & 8 deletions appbuilder/core/components/general_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ def _recognize(self, request: GeneralOCRRequest, timeout: float = None,
if not request.image and not request.url and not request.pdf_file and not request.ofd_file:
raise ValueError("one of image or url or must pdf_file or ofd_file be set")
data = GeneralOCRRequest.to_dict(request)
if self.retry.total != retry:
self.retry.total = retry
headers = self.auth_header()
if self.http_client.retry.total != retry:
self.http_client.retry.total = retry
headers = self.http_client.auth_header()
headers['content-type'] = 'application/x-www-form-urlencoded'
url = self.service_url("/v1/bce/aip/ocr/v1/accurate_basic")
response = self.s.post(url, headers=headers, data=data, timeout=timeout)
super().check_response_header(response)
url = self.http_client.service_url("/v1/bce/aip/ocr/v1/accurate_basic")
response = self.http_client.session.post(url, headers=headers, data=data, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
super().check_response_json(data)
self.http_client.check_response_json(data)
self.__class__._check_service_error(data)
request_id = self.response_request_id(response)
request_id = self.http_client.response_request_id(response)
ocr_response = GeneralOCRResponse.from_json(payload=json.dumps(data))
ocr_response.request_id = request_id
return ocr_response
Expand Down
Loading

0 comments on commit c41b7d1

Please sign in to comment.