Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
devops committed Jun 27, 2024
2 parents bce3f74 + 67edd93 commit 343310f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 50 deletions.
131 changes: 82 additions & 49 deletions pyk/src/pyk/kore/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,52 @@ def __init__(self, message: str, code: int, data: Any = None):


class Transport(ContextManager['Transport'], ABC):
_bug_report: BugReport | None
_bug_report_id: str | None

def __init__(self, bug_report_id: str | None = None, bug_report: BugReport | None = None) -> None:
if (bug_report_id is None and bug_report is not None) or (bug_report_id is not None and bug_report is None):
raise ValueError('bug_report and bug_report_id must be passed together.')
self._bug_report_id = bug_report_id
self._bug_report = bug_report

def request(self, req: str, request_id: int, method_name: str) -> str:
base_name = self._bug_report_id if self._bug_report_id is not None else 'kore_rpc'
req_name = f'{base_name}/{id(self)}/{request_id:03}'
if self._bug_report:
bug_report_request = f'{req_name}_request.json'
self._bug_report.add_file_contents(req, Path(bug_report_request))
self._bug_report.add_command(self._command(req_name, bug_report_request))

server_addr = self._description()
_LOGGER.info(f'Sending request to {server_addr}: {request_id} - {method_name}')
_LOGGER.debug(f'Sending request to {server_addr}: {req}')
resp = self._request(req)
_LOGGER.info(f'Received response from {server_addr}: {request_id} - {method_name}')
_LOGGER.debug(f'Received response from {server_addr}: {resp}')

if self._bug_report:
bug_report_response = f'{req_name}_response.json'
self._bug_report.add_file_contents(resp, Path(bug_report_response))
self._bug_report.add_command(
[
'diff',
'-b',
'-s',
f'{req_name}_actual.json',
f'{req_name}_response.json',
]
)
return resp

@abstractmethod
def _command(self, req_name: str, bug_report_request: str) -> list[str]: ...

@abstractmethod
def _request(self, req: str) -> str: ...

@abstractmethod
def request(self, req: str) -> str: ...
def _description(self) -> str: ...

def __enter__(self) -> Transport:
return self
Expand All @@ -68,12 +112,6 @@ def __exit__(self, *args: Any) -> None:
@abstractmethod
def close(self) -> None: ...

@abstractmethod
def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> list[str]: ...

@abstractmethod
def description(self) -> str: ...


class TransportType(Enum):
SINGLE_SOCKET = auto()
Expand All @@ -87,7 +125,16 @@ class SingleSocketTransport(Transport):
_sock: socket.socket
_file: TextIO

def __init__(self, host: str, port: int, *, timeout: int | None = None):
def __init__(
self,
host: str,
port: int,
*,
timeout: int | None = None,
bug_report_id: str | None = None,
bug_report: BugReport | None = None,
):
super().__init__(bug_report_id, bug_report)
self._host = host
self._port = port
self._sock = self._create_connection(host, port, timeout)
Expand Down Expand Up @@ -117,7 +164,7 @@ def close(self) -> None:
self._file.close()
self._sock.close()

def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> list[str]:
def _command(self, req_name: str, bug_report_request: str) -> list[str]:
return [
'cat',
bug_report_request,
Expand All @@ -127,16 +174,16 @@ def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> l
self._host,
str(self._port),
'>',
f'rpc_{bug_report_id}/{old_id:03}_actual.json',
f'{req_name}_actual.json',
]

def request(self, req: str) -> str:
def _request(self, req: str) -> str:
self._sock.sendall(req.encode())
server_addr = self.description()
server_addr = self._description()
_LOGGER.debug(f'Waiting for response from {server_addr}...')
return self._file.readline().rstrip()

def description(self) -> str:
def _description(self) -> str:
return f'{self._host}:{self._port}'


Expand All @@ -146,15 +193,24 @@ class HttpTransport(Transport):
_port: int
_timeout: int | None

def __init__(self, host: str, port: int, *, timeout: int | None = None):
def __init__(
self,
host: str,
port: int,
*,
timeout: int | None = None,
bug_report_id: str | None = None,
bug_report: BugReport | None = None,
):
super().__init__(bug_report_id, bug_report)
self._host = host
self._port = port
self._timeout = timeout

def close(self) -> None:
pass

def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> list[str]:
def _command(self, req_name: str, bug_report_request: str) -> list[str]:
return [
'curl',
'-X',
Expand All @@ -165,20 +221,20 @@ def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> l
'@' + bug_report_request,
'http://' + self._host + ':' + str(self._port),
'>',
f'rpc_{bug_report_id}/{old_id:03}_actual.json',
f'{req_name}_actual.json',
]

def request(self, req: str) -> str:
def _request(self, req: str) -> str:
connection = http.client.HTTPConnection(self._host, self._port, timeout=self._timeout)
connection.request('POST', '/', body=req, headers={'Content-Type': 'application/json'})
server_addr = self.description()
server_addr = self._description()
_LOGGER.debug(f'Waiting for response from {server_addr}...')
response = connection.getresponse()
if response.status != 200:
raise JsonRpcError('Internal server error', -32603)
return response.read().decode()

def description(self) -> str:
def _description(self) -> str:
return f'{self._host}:{self._port}'


Expand Down Expand Up @@ -258,8 +314,6 @@ class JsonRpcClient(ContextManager['JsonRpcClient']):

_transport: Transport
_req_id: int
_bug_report: BugReport | None
_bug_report_id: str

def __init__(
self,
Expand All @@ -272,14 +326,16 @@ def __init__(
transport: TransportType = TransportType.SINGLE_SOCKET,
):
if transport is TransportType.SINGLE_SOCKET:
self._transport = SingleSocketTransport(host, port, timeout=timeout)
self._transport = SingleSocketTransport(
host, port, timeout=timeout, bug_report=bug_report, bug_report_id=bug_report_id
)
elif transport is TransportType.HTTP:
self._transport = HttpTransport(host, port, timeout=timeout)
self._transport = HttpTransport(
host, port, timeout=timeout, bug_report=bug_report, bug_report_id=bug_report_id
)
else:
raise AssertionError()
self._req_id = 1
self._bug_report = bug_report
self._bug_report_id = bug_report_id if bug_report_id is not None else str(id(self))

def __enter__(self) -> JsonRpcClient:
return self
Expand All @@ -301,38 +357,15 @@ def request(self, method: str, **params: Any) -> dict[str, Any]:
'params': params,
}

server_addr = self._transport.description()
_LOGGER.info(f'Sending request to {server_addr}: {old_id} - {method}')
req = json.dumps(payload)
if self._bug_report:
bug_report_request = f'rpc_{self._bug_report_id}/{old_id:03}_request.json'
self._bug_report.add_file_contents(req, Path(bug_report_request))
self._bug_report.add_command(self._transport.command(self._bug_report_id, old_id, bug_report_request))

_LOGGER.debug(f'Sending request to {server_addr}: {req}')
resp = self._transport.request(req)
resp = self._transport.request(req, old_id, method)
if not resp:
raise RuntimeError('Empty response received')
_LOGGER.debug(f'Received response from {server_addr}: {resp}')

if self._bug_report:
bug_report_response = f'rpc_{self._bug_report_id}/{old_id:03}_response.json'
self._bug_report.add_file_contents(resp, Path(bug_report_response))
self._bug_report.add_command(
[
'diff',
'-b',
'-s',
f'rpc_{self._bug_report_id}/{old_id:03}_actual.json',
f'rpc_{self._bug_report_id}/{old_id:03}_response.json',
]
)

data = json.loads(resp)
self._check(data)
assert data['id'] == old_id

_LOGGER.info(f'Received response from {server_addr}: {old_id} - {method}')
return data['result']

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion pyk/src/tests/unit/kore/test_rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def transport(mock: Mock) -> MockTransport:
@pytest.fixture
def kore_client(mock: Mock, mock_class: Mock) -> Iterator[KoreClient]: # noqa: N803
client = KoreClient('localhost', 3000)
mock_class.assert_called_with('localhost', 3000, timeout=None)
mock_class.assert_called_with('localhost', 3000, timeout=None, bug_report=None, bug_report_id=None)
assert client._client._default_client._transport == mock
yield client
client.close()
Expand Down

0 comments on commit 343310f

Please sign in to comment.