From f771d21a5048717d8c59799a45ebae179dd7caba Mon Sep 17 00:00:00 2001 From: lyam Date: Thu, 6 Jun 2024 13:48:10 +0200 Subject: [PATCH] Feature: VmClient --- pyproject.toml | 1 + src/aleph/sdk/chains/common.py | 7 - src/aleph/sdk/chains/ethereum.py | 8 +- src/aleph/sdk/chains/substrate.py | 3 +- src/aleph/sdk/client/vmclient.py | 161 +++++++++++++++++++++++ src/aleph/sdk/types.py | 2 + src/aleph/sdk/utils.py | 11 ++ src/aleph/sdk/wallets/ledger/ethereum.py | 3 +- 8 files changed, 181 insertions(+), 15 deletions(-) create mode 100644 src/aleph/sdk/client/vmclient.py diff --git a/pyproject.toml b/pyproject.toml index 8a70e9c8..6fb23849 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "coincurve>=19.0.0; python_version>=\"3.11\"", "eth_abi>=4.0.0; python_version>=\"3.11\"", "eth_account>=0.4.0,<0.11.0", + "jwcrypto==1.5.6", "python-magic", "typer", "typing_extensions", diff --git a/src/aleph/sdk/chains/common.py b/src/aleph/sdk/chains/common.py index b73d6e41..0a90183c 100644 --- a/src/aleph/sdk/chains/common.py +++ b/src/aleph/sdk/chains/common.py @@ -170,10 +170,3 @@ def get_fallback_private_key(path: Optional[Path] = None) -> bytes: if not default_key_path.exists(): default_key_path.symlink_to(path) return private_key - - -def bytes_from_hex(hex_string: str) -> bytes: - if hex_string.startswith("0x"): - hex_string = hex_string[2:] - hex_string = bytes.fromhex(hex_string) - return hex_string diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index 124fbee7..b0fa5fbe 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -7,12 +7,8 @@ from eth_keys.exceptions import BadSignature as EthBadSignatureError from ..exceptions import BadSignatureError -from .common import ( - BaseAccount, - bytes_from_hex, - get_fallback_private_key, - get_public_key, -) +from ..utils import bytes_from_hex +from .common import BaseAccount, get_fallback_private_key, get_public_key class ETHAccount(BaseAccount): diff --git a/src/aleph/sdk/chains/substrate.py b/src/aleph/sdk/chains/substrate.py index 13795568..f4d18a0d 100644 --- a/src/aleph/sdk/chains/substrate.py +++ b/src/aleph/sdk/chains/substrate.py @@ -9,7 +9,8 @@ from ..conf import settings from ..exceptions import BadSignatureError -from .common import BaseAccount, bytes_from_hex, get_verification_buffer +from ..utils import bytes_from_hex +from .common import BaseAccount, get_verification_buffer logger = logging.getLogger(__name__) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py new file mode 100644 index 00000000..1cf66639 --- /dev/null +++ b/src/aleph/sdk/client/vmclient.py @@ -0,0 +1,161 @@ +import datetime +import json +import logging +from typing import Any, Dict, Tuple + +import aiohttp +from eth_account.messages import encode_defunct +from jwcrypto import jwk +from jwcrypto.jwa import JWA + +from aleph.sdk.types import Account +from aleph.sdk.utils import to_0x_hex + +logger = logging.getLogger(__name__) + + +class VmClient: + def __init__(self, account: Account, domain: str = ""): + self.account: Account = account + self.ephemeral_key: jwk.JWK = jwk.JWK.generate(kty="EC", crv="P-256") + self.domain: str = domain + self.pubkey_payload = self._generate_pubkey_payload() + self.pubkey_signature_header: str = "" + self.session = aiohttp.ClientSession() + + def _generate_pubkey_payload(self) -> Dict[str, Any]: + return { + "pubkey": json.loads(self.ephemeral_key.export_public()), + "alg": "ECDSA", + "domain": self.domain, + "address": self.account.get_address(), + "expires": ( + datetime.datetime.utcnow() + datetime.timedelta(days=1) + ).isoformat() + + "Z", + } + + async def _generate_pubkey_signature_header(self) -> str: + pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex() + signable_message = encode_defunct(hexstr=pubkey_payload) + buffer_to_sign = signable_message.body + + signed_message = await self.account.sign_raw(buffer_to_sign) + pubkey_signature = to_0x_hex(signed_message) + + return json.dumps( + { + "sender": self.account.get_address(), + "payload": pubkey_payload, + "signature": pubkey_signature, + "content": {"domain": self.domain}, + } + ) + + async def _generate_header( + self, vm_id: str, operation: str + ) -> Tuple[str, Dict[str, str]]: + base_url = f"http://{self.domain}" + path = ( + f"/logs/{vm_id}" + if operation == "logs" + else f"/control/machine/{vm_id}/{operation}" + ) + + payload = { + "time": datetime.datetime.utcnow().isoformat() + "Z", + "method": "POST", + "path": path, + } + payload_as_bytes = json.dumps(payload).encode("utf-8") + headers = {"X-SignedPubKey": self.pubkey_signature_header} + payload_signature = JWA.signing_alg("ES256").sign( + self.ephemeral_key, payload_as_bytes + ) + headers["X-SignedOperation"] = json.dumps( + { + "payload": payload_as_bytes.hex(), + "signature": payload_signature.hex(), + } + ) + + return f"{base_url}{path}", headers + + async def perform_operation(self, vm_id, operation): + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header(vm_id=vm_id, operation=operation) + + try: + async with self.session.post(url, headers=header) as response: + response_text = await response.text() + return response.status, response_text + except aiohttp.ClientError as e: + logger.error(f"HTTP error during operation {operation}: {str(e)}") + return None, str(e) + + async def get_logs(self, vm_id): + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + ws_url, header = await self._generate_header(vm_id=vm_id, operation="logs") + + async with aiohttp.ClientSession() as session: + async with session.ws_connect(ws_url) as ws: + auth_message = { + "auth": { + "X-SignedPubKey": header["X-SignedPubKey"], + "X-SignedOperation": header["X-SignedOperation"], + } + } + await ws.send_json(auth_message) + async for msg in ws: # msg is of type aiohttp.WSMessage + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + async def start_instance(self, vm_id): + return await self.notify_allocation(vm_id) + + async def stop_instance(self, vm_id): + return await self.perform_operation(vm_id, "stop") + + async def reboot_instance(self, vm_id): + + return await self.perform_operation(vm_id, "reboot") + + async def erase_instance(self, vm_id): + return await self.perform_operation(vm_id, "erase") + + async def expire_instance(self, vm_id): + return await self.perform_operation(vm_id, "expire") + + async def notify_allocation(self, vm_id) -> Tuple[Any, str]: + json_data = {"instance": vm_id} + async with self.session.post( + f"https://{self.domain}/control/allocation/notify", json=json_data + ) as s: + form_response_text = await s.text() + return s.status, form_response_text + + async def manage_instance(self, vm_id, operations): + for operation in operations: + status, response = await self.perform_operation(vm_id, operation) + if status != 200: + return status, response + return + + async def close(self): + await self.session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 8d17f4d4..d344123d 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -20,6 +20,8 @@ class Account(Protocol): @abstractmethod async def sign_message(self, message: Dict) -> Dict: ... + @abstractmethod + async def sign_raw(self, buffer: bytes) -> bytes: ... @abstractmethod def get_address(self) -> str: ... diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index b1c04cdf..ffd7ac6b 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -184,3 +184,14 @@ def parse_volume(volume_dict: Union[Mapping, MachineVolume]) -> MachineVolume: def compute_sha256(s: str) -> str: """Compute the SHA256 hash of a string.""" return hashlib.sha256(s.encode()).hexdigest() + + +def to_0x_hex(b: bytes) -> str: + return "0x" + bytes.hex(b) + + +def bytes_from_hex(hex_string: str) -> bytes: + if hex_string.startswith("0x"): + hex_string = hex_string[2:] + hex_string = bytes.fromhex(hex_string) + return hex_string diff --git a/src/aleph/sdk/wallets/ledger/ethereum.py b/src/aleph/sdk/wallets/ledger/ethereum.py index 2ecdc5d3..5dc40f03 100644 --- a/src/aleph/sdk/wallets/ledger/ethereum.py +++ b/src/aleph/sdk/wallets/ledger/ethereum.py @@ -9,7 +9,8 @@ from ledgereth.messages import sign_message from ledgereth.objects import LedgerAccount, SignedMessage -from ...chains.common import BaseAccount, bytes_from_hex, get_verification_buffer +from ...chains.common import BaseAccount, get_verification_buffer +from ...utils import bytes_from_hex class LedgerETHAccount(BaseAccount):