From f771d21a5048717d8c59799a45ebae179dd7caba Mon Sep 17 00:00:00 2001 From: lyam Date: Thu, 6 Jun 2024 13:48:10 +0200 Subject: [PATCH 01/34] Feature: VmClient --- pyproject.toml | 1 + src/aleph/sdk/chains/common.py | 7 - src/aleph/sdk/chains/ethereum.py | 8 +- src/aleph/sdk/chains/substrate.py | 3 +- src/aleph/sdk/client/vmclient.py | 161 +++++++++++++++++++++++ src/aleph/sdk/types.py | 2 + src/aleph/sdk/utils.py | 11 ++ src/aleph/sdk/wallets/ledger/ethereum.py | 3 +- 8 files changed, 181 insertions(+), 15 deletions(-) create mode 100644 src/aleph/sdk/client/vmclient.py diff --git a/pyproject.toml b/pyproject.toml index 8a70e9c8..6fb23849 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "coincurve>=19.0.0; python_version>=\"3.11\"", "eth_abi>=4.0.0; python_version>=\"3.11\"", "eth_account>=0.4.0,<0.11.0", + "jwcrypto==1.5.6", "python-magic", "typer", "typing_extensions", diff --git a/src/aleph/sdk/chains/common.py b/src/aleph/sdk/chains/common.py index b73d6e41..0a90183c 100644 --- a/src/aleph/sdk/chains/common.py +++ b/src/aleph/sdk/chains/common.py @@ -170,10 +170,3 @@ def get_fallback_private_key(path: Optional[Path] = None) -> bytes: if not default_key_path.exists(): default_key_path.symlink_to(path) return private_key - - -def bytes_from_hex(hex_string: str) -> bytes: - if hex_string.startswith("0x"): - hex_string = hex_string[2:] - hex_string = bytes.fromhex(hex_string) - return hex_string diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index 124fbee7..b0fa5fbe 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -7,12 +7,8 @@ from eth_keys.exceptions import BadSignature as EthBadSignatureError from ..exceptions import BadSignatureError -from .common import ( - BaseAccount, - bytes_from_hex, - get_fallback_private_key, - get_public_key, -) +from ..utils import bytes_from_hex +from .common import BaseAccount, get_fallback_private_key, get_public_key class ETHAccount(BaseAccount): diff --git a/src/aleph/sdk/chains/substrate.py b/src/aleph/sdk/chains/substrate.py index 13795568..f4d18a0d 100644 --- a/src/aleph/sdk/chains/substrate.py +++ b/src/aleph/sdk/chains/substrate.py @@ -9,7 +9,8 @@ from ..conf import settings from ..exceptions import BadSignatureError -from .common import BaseAccount, bytes_from_hex, get_verification_buffer +from ..utils import bytes_from_hex +from .common import BaseAccount, get_verification_buffer logger = logging.getLogger(__name__) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py new file mode 100644 index 00000000..1cf66639 --- /dev/null +++ b/src/aleph/sdk/client/vmclient.py @@ -0,0 +1,161 @@ +import datetime +import json +import logging +from typing import Any, Dict, Tuple + +import aiohttp +from eth_account.messages import encode_defunct +from jwcrypto import jwk +from jwcrypto.jwa import JWA + +from aleph.sdk.types import Account +from aleph.sdk.utils import to_0x_hex + +logger = logging.getLogger(__name__) + + +class VmClient: + def __init__(self, account: Account, domain: str = ""): + self.account: Account = account + self.ephemeral_key: jwk.JWK = jwk.JWK.generate(kty="EC", crv="P-256") + self.domain: str = domain + self.pubkey_payload = self._generate_pubkey_payload() + self.pubkey_signature_header: str = "" + self.session = aiohttp.ClientSession() + + def _generate_pubkey_payload(self) -> Dict[str, Any]: + return { + "pubkey": json.loads(self.ephemeral_key.export_public()), + "alg": "ECDSA", + "domain": self.domain, + "address": self.account.get_address(), + "expires": ( + datetime.datetime.utcnow() + datetime.timedelta(days=1) + ).isoformat() + + "Z", + } + + async def _generate_pubkey_signature_header(self) -> str: + pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex() + signable_message = encode_defunct(hexstr=pubkey_payload) + buffer_to_sign = signable_message.body + + signed_message = await self.account.sign_raw(buffer_to_sign) + pubkey_signature = to_0x_hex(signed_message) + + return json.dumps( + { + "sender": self.account.get_address(), + "payload": pubkey_payload, + "signature": pubkey_signature, + "content": {"domain": self.domain}, + } + ) + + async def _generate_header( + self, vm_id: str, operation: str + ) -> Tuple[str, Dict[str, str]]: + base_url = f"http://{self.domain}" + path = ( + f"/logs/{vm_id}" + if operation == "logs" + else f"/control/machine/{vm_id}/{operation}" + ) + + payload = { + "time": datetime.datetime.utcnow().isoformat() + "Z", + "method": "POST", + "path": path, + } + payload_as_bytes = json.dumps(payload).encode("utf-8") + headers = {"X-SignedPubKey": self.pubkey_signature_header} + payload_signature = JWA.signing_alg("ES256").sign( + self.ephemeral_key, payload_as_bytes + ) + headers["X-SignedOperation"] = json.dumps( + { + "payload": payload_as_bytes.hex(), + "signature": payload_signature.hex(), + } + ) + + return f"{base_url}{path}", headers + + async def perform_operation(self, vm_id, operation): + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header(vm_id=vm_id, operation=operation) + + try: + async with self.session.post(url, headers=header) as response: + response_text = await response.text() + return response.status, response_text + except aiohttp.ClientError as e: + logger.error(f"HTTP error during operation {operation}: {str(e)}") + return None, str(e) + + async def get_logs(self, vm_id): + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + ws_url, header = await self._generate_header(vm_id=vm_id, operation="logs") + + async with aiohttp.ClientSession() as session: + async with session.ws_connect(ws_url) as ws: + auth_message = { + "auth": { + "X-SignedPubKey": header["X-SignedPubKey"], + "X-SignedOperation": header["X-SignedOperation"], + } + } + await ws.send_json(auth_message) + async for msg in ws: # msg is of type aiohttp.WSMessage + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + async def start_instance(self, vm_id): + return await self.notify_allocation(vm_id) + + async def stop_instance(self, vm_id): + return await self.perform_operation(vm_id, "stop") + + async def reboot_instance(self, vm_id): + + return await self.perform_operation(vm_id, "reboot") + + async def erase_instance(self, vm_id): + return await self.perform_operation(vm_id, "erase") + + async def expire_instance(self, vm_id): + return await self.perform_operation(vm_id, "expire") + + async def notify_allocation(self, vm_id) -> Tuple[Any, str]: + json_data = {"instance": vm_id} + async with self.session.post( + f"https://{self.domain}/control/allocation/notify", json=json_data + ) as s: + form_response_text = await s.text() + return s.status, form_response_text + + async def manage_instance(self, vm_id, operations): + for operation in operations: + status, response = await self.perform_operation(vm_id, operation) + if status != 200: + return status, response + return + + async def close(self): + await self.session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 8d17f4d4..d344123d 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -20,6 +20,8 @@ class Account(Protocol): @abstractmethod async def sign_message(self, message: Dict) -> Dict: ... + @abstractmethod + async def sign_raw(self, buffer: bytes) -> bytes: ... @abstractmethod def get_address(self) -> str: ... diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index b1c04cdf..ffd7ac6b 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -184,3 +184,14 @@ def parse_volume(volume_dict: Union[Mapping, MachineVolume]) -> MachineVolume: def compute_sha256(s: str) -> str: """Compute the SHA256 hash of a string.""" return hashlib.sha256(s.encode()).hexdigest() + + +def to_0x_hex(b: bytes) -> str: + return "0x" + bytes.hex(b) + + +def bytes_from_hex(hex_string: str) -> bytes: + if hex_string.startswith("0x"): + hex_string = hex_string[2:] + hex_string = bytes.fromhex(hex_string) + return hex_string diff --git a/src/aleph/sdk/wallets/ledger/ethereum.py b/src/aleph/sdk/wallets/ledger/ethereum.py index 2ecdc5d3..5dc40f03 100644 --- a/src/aleph/sdk/wallets/ledger/ethereum.py +++ b/src/aleph/sdk/wallets/ledger/ethereum.py @@ -9,7 +9,8 @@ from ledgereth.messages import sign_message from ledgereth.objects import LedgerAccount, SignedMessage -from ...chains.common import BaseAccount, bytes_from_hex, get_verification_buffer +from ...chains.common import BaseAccount, get_verification_buffer +from ...utils import bytes_from_hex class LedgerETHAccount(BaseAccount): From ac78d5367acb6ee3436f9b801ab1ca48ad293805 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Wed, 19 Jun 2024 15:46:58 +0200 Subject: [PATCH 02/34] Fix: Protocol (http/https) should not be hardcoded. One place hardcoded `http://`, the other one `https://`. --- src/aleph/sdk/client/vmclient.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 1cf66639..8aedb1ec 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -1,7 +1,7 @@ import datetime import json import logging -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, Optional import aiohttp from eth_account.messages import encode_defunct @@ -15,19 +15,31 @@ class VmClient: - def __init__(self, account: Account, domain: str = ""): + account: Account + ephemeral_key: jwk.JWK + node_url: str + pubkey_payload: Dict[str, Any] + pubkey_signature_header: str + session: aiohttp.ClientSession + + def __init__( + self, + account: Account, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): self.account: Account = account self.ephemeral_key: jwk.JWK = jwk.JWK.generate(kty="EC", crv="P-256") - self.domain: str = domain + self.node_url: str = node_url self.pubkey_payload = self._generate_pubkey_payload() self.pubkey_signature_header: str = "" - self.session = aiohttp.ClientSession() + self.session = session or aiohttp.ClientSession() def _generate_pubkey_payload(self) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", - "domain": self.domain, + "domain": self.node_url, "address": self.account.get_address(), "expires": ( datetime.datetime.utcnow() + datetime.timedelta(days=1) @@ -48,14 +60,13 @@ async def _generate_pubkey_signature_header(self) -> str: "sender": self.account.get_address(), "payload": pubkey_payload, "signature": pubkey_signature, - "content": {"domain": self.domain}, + "content": {"domain": self.node_url}, } ) async def _generate_header( self, vm_id: str, operation: str ) -> Tuple[str, Dict[str, str]]: - base_url = f"http://{self.domain}" path = ( f"/logs/{vm_id}" if operation == "logs" @@ -79,7 +90,7 @@ async def _generate_header( } ) - return f"{base_url}{path}", headers + return f"{self.node_url}{path}", headers async def perform_operation(self, vm_id, operation): if not self.pubkey_signature_header: @@ -139,7 +150,7 @@ async def expire_instance(self, vm_id): async def notify_allocation(self, vm_id) -> Tuple[Any, str]: json_data = {"instance": vm_id} async with self.session.post( - f"https://{self.domain}/control/allocation/notify", json=json_data + f"{self.node_url}/control/allocation/notify", json=json_data ) as s: form_response_text = await s.text() return s.status, form_response_text From 5672776ea0d71dfd626ffbf58d291c7707e5fccd Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Wed, 19 Jun 2024 15:47:31 +0200 Subject: [PATCH 03/34] Fix: There was no test for `notify_allocation()`. --- pyproject.toml | 2 ++ tests/unit/test_vmclient.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/unit/test_vmclient.py diff --git a/pyproject.toml b/pyproject.toml index 6fb23849..81962db8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,8 @@ dependencies = [ "pytest-cov==4.1.0", "pytest-mock==3.12.0", "pytest-asyncio==0.23.5", + "jwcrypto==1.5.6", + "aioresponses==0.7.6", "fastapi", "httpx", "secp256k1", diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py new file mode 100644 index 00000000..e8168c6a --- /dev/null +++ b/tests/unit/test_vmclient.py @@ -0,0 +1,23 @@ +import aiohttp +import pytest +from aioresponses import aioresponses +from aleph_message.models import ItemHash + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vmclient import VmClient + + +@pytest.mark.asyncio +async def test_notify_allocation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post("http://localhost/control/allocation/notify", status=200) + await vm_client.notify_allocation(vm_id=vm_id) + assert m.requests From 301814a218850807863f40483995fd654b76ae55 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Wed, 19 Jun 2024 15:47:49 +0200 Subject: [PATCH 04/34] WIP: Copy authentication functions from aleph-vm --- tests/unit/aleph_vm_authentication.py | 260 ++++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 tests/unit/aleph_vm_authentication.py diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py new file mode 100644 index 00000000..b4448911 --- /dev/null +++ b/tests/unit/aleph_vm_authentication.py @@ -0,0 +1,260 @@ +# Keep datetime import as is as it allow patching in test +import datetime +import functools +import json +import logging +from collections.abc import Awaitable, Coroutine +from typing import Any, Callable, Literal, Union + +import cryptography.exceptions +import pydantic +from aiohttp import web +from eth_account import Account +from eth_account.messages import encode_defunct +from jwcrypto import jwk +from jwcrypto.jwa import JWA +from pydantic import BaseModel, ValidationError, root_validator, validator + + +logger = logging.getLogger(__name__) + + +def is_token_still_valid(datestr: str): + """ + Checks if a token has expired based on its expiry timestamp + """ + current_datetime = datetime.datetime.now(tz=datetime.timezone.utc) + expiry_datetime = datetime.datetime.fromisoformat(datestr.replace("Z", "+00:00")) + + return expiry_datetime > current_datetime + + +def verify_wallet_signature(signature, message, address): + """ + Verifies a signature issued by a wallet + """ + enc_msg = encode_defunct(hexstr=message) + computed_address = Account.recover_message(enc_msg, signature=signature) + return computed_address.lower() == address.lower() + + +class SignedPubKeyPayload(BaseModel): + """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" + + pubkey: dict[str, Any] + # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', + # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} + # alg: Literal["ECDSA"] + domain: str + address: str + expires: str + + @property + def json_web_key(self) -> jwk.JWK: + """Return the ephemeral public key as Json Web Key""" + return jwk.JWK(**self.pubkey) + + +class SignedPubKeyHeader(BaseModel): + signature: bytes + payload: bytes + + @validator("signature") + def signature_must_be_hex(cls, v: bytes) -> bytes: + """Convert the signature from hexadecimal to bytes""" + return bytes.fromhex(v.removeprefix(b"0x").decode()) + + @validator("payload") + def payload_must_be_hex(cls, v: bytes) -> bytes: + """Convert the payload from hexadecimal to bytes""" + return bytes.fromhex(v.decode()) + + @root_validator(pre=False, skip_on_failure=True) + def check_expiry(cls, values) -> dict[str, bytes]: + """Check that the token has not expired""" + payload: bytes = values["payload"] + content = SignedPubKeyPayload.parse_raw(payload) + if not is_token_still_valid(content.expires): + msg = "Token expired" + raise ValueError(msg) + return values + + @root_validator(pre=False, skip_on_failure=True) + def check_signature(cls, values) -> dict[str, bytes]: + """Check that the signature is valid""" + signature: bytes = values["signature"] + payload: bytes = values["payload"] + content = SignedPubKeyPayload.parse_raw(payload) + if not verify_wallet_signature(signature, payload.hex(), content.address): + msg = "Invalid signature" + raise ValueError(msg) + return values + + @property + def content(self) -> SignedPubKeyPayload: + """Return the content of the header""" + return SignedPubKeyPayload.parse_raw(self.payload) + + +class SignedOperationPayload(BaseModel): + time: datetime.datetime + method: Union[Literal["POST"], Literal["GET"]] + path: str + # body_sha256: str # disabled since there is no body + + @validator("time") + def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: + """Check that the time is current and the payload is not a replay attack.""" + max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( + minutes=2 + ) + max_future = datetime.datetime.now( + tz=datetime.timezone.utc + ) + datetime.timedelta(minutes=2) + if v < max_past: + raise ValueError("Time is too far in the past") + if v > max_future: + raise ValueError("Time is too far in the future") + return v + + +class SignedOperation(BaseModel): + """This payload is signed by the ephemeral key authorized above.""" + + signature: bytes + payload: bytes + + @validator("signature") + def signature_must_be_hex(cls, v) -> bytes: + """Convert the signature from hexadecimal to bytes""" + try: + return bytes.fromhex(v.removeprefix(b"0x").decode()) + except pydantic.ValidationError as error: + print(v) + logger.warning(v) + raise error + + @validator("payload") + def payload_must_be_hex(cls, v) -> bytes: + """Convert the payload from hexadecimal to bytes""" + v = bytes.fromhex(v.decode()) + _ = SignedOperationPayload.parse_raw(v) + return v + + @property + def content(self) -> SignedOperationPayload: + """Return the content of the header""" + return SignedOperationPayload.parse_raw(self.payload) + + +def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: + """Get the ephemeral public key that is signed by the wallet from the request headers.""" + signed_pubkey_header = request.headers.get("X-SignedPubKey") + if not signed_pubkey_header: + raise web.HTTPBadRequest(reason="Missing X-SignedPubKey header") + + try: + return SignedPubKeyHeader.parse_raw(signed_pubkey_header) + except KeyError as error: + logger.debug(f"Missing X-SignedPubKey header: {error}") + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey fields") from error + except json.JSONDecodeError as error: + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from error + except ValueError as errors: + logging.debug(errors) + for err in errors.args[0]: + if isinstance(err.exc, json.JSONDecodeError): + raise web.HTTPBadRequest( + reason="Invalid X-SignedPubKey format" + ) from errors + if str(err.exc) == "Token expired": + raise web.HTTPUnauthorized(reason="Token expired") from errors + if str(err.exc) == "Invalid signature": + raise web.HTTPUnauthorized(reason="Invalid signature") from errors + else: + raise errors + + +def get_signed_operation(request: web.Request) -> SignedOperation: + """Get the signed operation public key that is signed by the ephemeral key from the request headers.""" + try: + signed_operation = request.headers["X-SignedOperation"] + return SignedOperation.parse_raw(signed_operation) + except KeyError as error: + raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error + except json.JSONDecodeError as error: + raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error + except ValidationError as error: + logger.debug(f"Invalid X-SignedOperation fields: {error}") + raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error + + +def verify_signed_operation( + signed_operation: SignedOperation, signed_pubkey: SignedPubKeyHeader +) -> str: + """Verify that the operation is signed by the ephemeral key authorized by the wallet.""" + pubkey = signed_pubkey.content.json_web_key + + try: + JWA.signing_alg("ES256").verify( + pubkey, signed_operation.payload, signed_operation.signature + ) + logger.debug("Signature verified") + return signed_pubkey.content.address + except cryptography.exceptions.InvalidSignature as e: + logger.debug("Failing to validate signature for operation", e) + raise web.HTTPUnauthorized(reason="Signature could not verified") + + +async def authenticate_jwk(request: web.Request) -> str: + """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" + signed_pubkey = get_signed_pubkey(request) + signed_operation = get_signed_operation(request) + if signed_pubkey.content.domain != settings.DOMAIN_NAME: + logger.debug( + f"Invalid domain '{signed_pubkey.content.domain}' != '{settings.DOMAIN_NAME}'" + ) + raise web.HTTPUnauthorized(reason="Invalid domain") + if signed_operation.content.path != request.path: + logger.debug( + f"Invalid path '{signed_operation.content.path}' != '{request.path}'" + ) + raise web.HTTPUnauthorized(reason="Invalid path") + if signed_operation.content.method != request.method: + logger.debug( + f"Invalid method '{signed_operation.content.method}' != '{request.method}'" + ) + raise web.HTTPUnauthorized(reason="Invalid method") + return verify_signed_operation(signed_operation, signed_pubkey) + + +async def authenticate_websocket_message(message) -> str: + """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" + signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) + signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) + if signed_pubkey.content.node_url != settings.DOMAIN_NAME: + logger.debug( + f"Invalid domain '{signed_pubkey.content.node_url}' != '{settings.DOMAIN_NAME}'" + ) + raise web.HTTPUnauthorized(reason="Invalid domain") + return verify_signed_operation(signed_operation, signed_pubkey) + + +def require_jwk_authentication( + handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] +) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: + @functools.wraps(handler) + async def wrapper(request): + try: + authenticated_sender: str = await authenticate_jwk(request) + except web.HTTPException as e: + return web.json_response(data={"error": e.reason}, status=e.status) + except Exception as e: + # Unexpected make sure to log it + logging.exception(e) + raise + + response = await handler(request, authenticated_sender) + return response + + return wrapper From 9abb642ee2543ea45767ef034e89462e16eac045 Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 16:55:34 +0200 Subject: [PATCH 05/34] Fix: vm client sessions wasn't close + authentifications for test will use localhost as domain --- tests/unit/aleph_vm_authentication.py | 9 +++++---- tests/unit/test_vmclient.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index b4448911..c03ce211 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -18,6 +18,7 @@ logger = logging.getLogger(__name__) +DOMAIN_NAME = "localhost" def is_token_still_valid(datestr: str): """ @@ -210,9 +211,9 @@ async def authenticate_jwk(request: web.Request) -> str: """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) - if signed_pubkey.content.domain != settings.DOMAIN_NAME: + if signed_pubkey.content.domain != DOMAIN_NAME: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{settings.DOMAIN_NAME}'" + f"Invalid domain '{signed_pubkey.content.domain}' != '{DOMAIN_NAME}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") if signed_operation.content.path != request.path: @@ -232,9 +233,9 @@ async def authenticate_websocket_message(message) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) - if signed_pubkey.content.node_url != settings.DOMAIN_NAME: + if signed_pubkey.content.node_url != DOMAIN_NAME: logger.debug( - f"Invalid domain '{signed_pubkey.content.node_url}' != '{settings.DOMAIN_NAME}'" + f"Invalid domain '{signed_pubkey.content.node_url}' != '{DOMAIN_NAME}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") return verify_signed_operation(signed_operation, signed_pubkey) diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index e8168c6a..85c91074 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -21,3 +21,4 @@ async def test_notify_allocation(): m.post("http://localhost/control/allocation/notify", status=200) await vm_client.notify_allocation(vm_id=vm_id) assert m.requests + await vm_client.session.close() From 68600157d8962319cf35c53d3f666e13ff05fa88 Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 18:00:41 +0200 Subject: [PATCH 06/34] Add: Unit test for {perform_operation, stop, reboot, erase, expire} --- tests/unit/test_vmclient.py | 91 +++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index 85c91074..64da9979 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -22,3 +22,94 @@ async def test_notify_allocation(): await vm_client.notify_allocation(vm_id=vm_id) assert m.requests await vm_client.session.close() + +@pytest.mark.asyncio +async def test_perform_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "reboot" + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post(f"http://localhost/control/machine/{vm_id}/{operation}", status=200, payload="mock_response_text") + + status, response_text = await vm_client.perform_operation(vm_id, operation) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + +@pytest.mark.asyncio +async def test_stop_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post(f"http://localhost/control/machine/{vm_id}/stop", status=200, payload="mock_response_text") + + status, response_text = await vm_client.stop_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + +@pytest.mark.asyncio +async def test_reboot_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post(f"http://localhost/control/machine/{vm_id}/reboot", status=200, payload="mock_response_text") + + status, response_text = await vm_client.reboot_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + +@pytest.mark.asyncio +async def test_erase_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post(f"http://localhost/control/machine/{vm_id}/erase", status=200, payload="mock_response_text") + + status, response_text = await vm_client.erase_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + +@pytest.mark.asyncio +async def test_expire_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post(f"http://localhost/control/machine/{vm_id}/expire", status=200, payload="mock_response_text") + + status, response_text = await vm_client.expire_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() \ No newline at end of file From 52faf6ee2694bd56a3ce7e7b6930e06d4cb40cbf Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 18:18:07 +0200 Subject: [PATCH 07/34] Refactor: logs didn't need to generate full header Fix: extracts domain from node url instead of sending url Fix: using vmclient sessions in get_logs instead of creating new one --- src/aleph/sdk/client/vmclient.py | 64 ++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 8aedb1ec..e5ff0708 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -1,7 +1,8 @@ import datetime import json import logging -from typing import Any, Dict, Tuple, Optional +from typing import Any, Dict, Optional, Tuple +from urllib.parse import urlparse import aiohttp from eth_account.messages import encode_defunct @@ -39,7 +40,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", - "domain": self.node_url, + "domain": urlparse(self.node_url).netloc, "address": self.account.get_address(), "expires": ( datetime.datetime.utcnow() + datetime.timedelta(days=1) @@ -60,36 +61,51 @@ async def _generate_pubkey_signature_header(self) -> str: "sender": self.account.get_address(), "payload": pubkey_payload, "signature": pubkey_signature, - "content": {"domain": self.node_url}, + "content": {"domain": urlparse(self.node_url).netloc}, } ) - async def _generate_header( - self, vm_id: str, operation: str - ) -> Tuple[str, Dict[str, str]]: + def create_payload(self, vm_id: str, operation: str) -> Dict[str, str]: path = ( f"/logs/{vm_id}" if operation == "logs" else f"/control/machine/{vm_id}/{operation}" ) - payload = { "time": datetime.datetime.utcnow().isoformat() + "Z", "method": "POST", "path": path, } + return payload + + def sign_payload(self, payload: Dict[str, str], ephemeral_key) -> str: payload_as_bytes = json.dumps(payload).encode("utf-8") - headers = {"X-SignedPubKey": self.pubkey_signature_header} payload_signature = JWA.signing_alg("ES256").sign( - self.ephemeral_key, payload_as_bytes + ephemeral_key, payload_as_bytes ) - headers["X-SignedOperation"] = json.dumps( + signed_operation = json.dumps( { "payload": payload_as_bytes.hex(), "signature": payload_signature.hex(), } ) + return signed_operation + async def _generate_header( + self, vm_id: str, operation: str + ) -> Tuple[str, Dict[str, str]]: + payload = self.create_payload(vm_id, operation) + signed_operation = self.sign_payload(payload, self.ephemeral_key) + + if not self.pubkey_signature_header: + self.pubkey_signature_header = await self.generate_pubkey_signature_header() + + headers = { + "X-SignedPubKey": self.pubkey_signature_header, + "X-SignedOperation": signed_operation, + } + + path = payload["path"] return f"{self.node_url}{path}", headers async def perform_operation(self, vm_id, operation): @@ -114,22 +130,24 @@ async def get_logs(self, vm_id): await self._generate_pubkey_signature_header() ) + payload = self.create_payload(vm_id, "logs") + signed_operation = self.sign_payload(payload, self.ephemeral_key) + ws_url, header = await self._generate_header(vm_id=vm_id, operation="logs") - async with aiohttp.ClientSession() as session: - async with session.ws_connect(ws_url) as ws: - auth_message = { - "auth": { - "X-SignedPubKey": header["X-SignedPubKey"], - "X-SignedOperation": header["X-SignedOperation"], - } + async with self.session.ws_connect(ws_url) as ws: + auth_message = { + "auth": { + "X-SignedPubKey": self.pubkey_signature_header, + "X-SignedOperation": signed_operation, } - await ws.send_json(auth_message) - async for msg in ws: # msg is of type aiohttp.WSMessage - if msg.type == aiohttp.WSMsgType.TEXT: - yield msg.data - elif msg.type == aiohttp.WSMsgType.ERROR: - break + } + await ws.send_json(auth_message) + async for msg in ws: # msg is of type aiohttp.WSMessage + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.ERROR: + break async def start_instance(self, vm_id): return await self.notify_allocation(vm_id) From 328e08780e53d4bf5826f5adcf78c859d865e590 Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 19:14:48 +0200 Subject: [PATCH 08/34] Add: get_logs test --- tests/unit/test_vmclient.py | 86 ++++++++++++++++++++++++++++++++----- 1 file changed, 75 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index 64da9979..8733954d 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -1,5 +1,6 @@ import aiohttp import pytest +from aiohttp import web from aioresponses import aioresponses from aleph_message.models import ItemHash @@ -23,6 +24,7 @@ async def test_notify_allocation(): assert m.requests await vm_client.session.close() + @pytest.mark.asyncio async def test_perform_operation(): account = ETHAccount(private_key=b"0x" + b"1" * 30) @@ -35,13 +37,18 @@ async def test_perform_operation(): node_url="http://localhost", session=aiohttp.ClientSession(), ) - m.post(f"http://localhost/control/machine/{vm_id}/{operation}", status=200, payload="mock_response_text") + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) status, response_text = await vm_client.perform_operation(vm_id, operation) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses await vm_client.session.close() + @pytest.mark.asyncio async def test_stop_instance(): account = ETHAccount(private_key=b"0x" + b"1" * 30) @@ -53,13 +60,18 @@ async def test_stop_instance(): node_url="http://localhost", session=aiohttp.ClientSession(), ) - m.post(f"http://localhost/control/machine/{vm_id}/stop", status=200, payload="mock_response_text") + m.post( + f"http://localhost/control/machine/{vm_id}/stop", + status=200, + payload="mock_response_text", + ) status, response_text = await vm_client.stop_instance(vm_id) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses await vm_client.session.close() + @pytest.mark.asyncio async def test_reboot_instance(): account = ETHAccount(private_key=b"0x" + b"1" * 30) @@ -71,13 +83,18 @@ async def test_reboot_instance(): node_url="http://localhost", session=aiohttp.ClientSession(), ) - m.post(f"http://localhost/control/machine/{vm_id}/reboot", status=200, payload="mock_response_text") + m.post( + f"http://localhost/control/machine/{vm_id}/reboot", + status=200, + payload="mock_response_text", + ) status, response_text = await vm_client.reboot_instance(vm_id) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses await vm_client.session.close() + @pytest.mark.asyncio async def test_erase_instance(): account = ETHAccount(private_key=b"0x" + b"1" * 30) @@ -89,13 +106,18 @@ async def test_erase_instance(): node_url="http://localhost", session=aiohttp.ClientSession(), ) - m.post(f"http://localhost/control/machine/{vm_id}/erase", status=200, payload="mock_response_text") + m.post( + f"http://localhost/control/machine/{vm_id}/erase", + status=200, + payload="mock_response_text", + ) status, response_text = await vm_client.erase_instance(vm_id) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses await vm_client.session.close() + @pytest.mark.asyncio async def test_expire_instance(): account = ETHAccount(private_key=b"0x" + b"1" * 30) @@ -107,9 +129,51 @@ async def test_expire_instance(): node_url="http://localhost", session=aiohttp.ClientSession(), ) - m.post(f"http://localhost/control/machine/{vm_id}/expire", status=200, payload="mock_response_text") + m.post( + f"http://localhost/control/machine/{vm_id}/expire", + status=200, + payload="mock_response_text", + ) status, response_text = await vm_client.expire_instance(vm_id) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses - await vm_client.session.close() \ No newline at end of file + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_get_logs(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await ws.send_str("mock_log_entry") + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + return ws + + app = web.Application() + app.router.add_route("GET", "/logs/{vm_id}", websocket_handler) + + client = await aiohttp_client(app) + + vm_client = VmClient( + account=account, + node_url=str(client.make_url("/")).rstrip("/"), + session=client.session, + ) + + logs = [] + async for log in vm_client.get_logs(vm_id): + logs.append(log) + if log == "mock_log_entry": + break + + assert logs == ["mock_log_entry"] + await vm_client.session.close() From 5c61b9b5af20a9a8ee85bc582ab1790136a996c4 Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 19:16:44 +0200 Subject: [PATCH 09/34] Fix: black in aleph_vm_authentification.py fix: isort issue Fix: mypy issue Fix: black Fix: isort --- src/aleph/sdk/client/vmclient.py | 4 +++- tests/unit/aleph_vm_authentication.py | 12 ++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index e5ff0708..440c24ea 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -98,7 +98,9 @@ async def _generate_header( signed_operation = self.sign_payload(payload, self.ephemeral_key) if not self.pubkey_signature_header: - self.pubkey_signature_header = await self.generate_pubkey_signature_header() + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) headers = { "X-SignedPubKey": self.pubkey_signature_header, diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index c03ce211..2ca144a8 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Awaitable, Coroutine -from typing import Any, Callable, Literal, Union +from typing import Any, Callable, Dict, Literal, Union import cryptography.exceptions import pydantic @@ -15,11 +15,11 @@ from jwcrypto.jwa import JWA from pydantic import BaseModel, ValidationError, root_validator, validator - logger = logging.getLogger(__name__) DOMAIN_NAME = "localhost" + def is_token_still_valid(datestr: str): """ Checks if a token has expired based on its expiry timestamp @@ -42,7 +42,7 @@ def verify_wallet_signature(signature, message, address): class SignedPubKeyPayload(BaseModel): """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" - pubkey: dict[str, Any] + pubkey: Dict[str, Any] # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} # alg: Literal["ECDSA"] @@ -63,7 +63,7 @@ class SignedPubKeyHeader(BaseModel): @validator("signature") def signature_must_be_hex(cls, v: bytes) -> bytes: """Convert the signature from hexadecimal to bytes""" - return bytes.fromhex(v.removeprefix(b"0x").decode()) + return bytes.fromhex(v.decode()) @validator("payload") def payload_must_be_hex(cls, v: bytes) -> bytes: @@ -71,7 +71,7 @@ def payload_must_be_hex(cls, v: bytes) -> bytes: return bytes.fromhex(v.decode()) @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values) -> dict[str, bytes]: + def check_expiry(cls, values) -> Dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) @@ -81,7 +81,7 @@ def check_expiry(cls, values) -> dict[str, bytes]: return values @root_validator(pre=False, skip_on_failure=True) - def check_signature(cls, values) -> dict[str, bytes]: + def check_signature(cls, values) -> Dict[str, bytes]: """Check that the signature is valid""" signature: bytes = values["signature"] payload: bytes = values["payload"] From a30f6909886e7d308b5ed3b536f65e5efe104bc4 Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 19:53:07 +0200 Subject: [PATCH 10/34] Fix: fully remove _generate_header call in get_logs Fix Fix: using real path server instead of fake server for test Fix: create playload --- src/aleph/sdk/client/vmclient.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 440c24ea..226c443c 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -67,9 +67,7 @@ async def _generate_pubkey_signature_header(self) -> str: def create_payload(self, vm_id: str, operation: str) -> Dict[str, str]: path = ( - f"/logs/{vm_id}" - if operation == "logs" - else f"/control/machine/{vm_id}/{operation}" + f"/control/machine/{vm_id}/{operation}" ) payload = { "time": datetime.datetime.utcnow().isoformat() + "Z", @@ -134,8 +132,8 @@ async def get_logs(self, vm_id): payload = self.create_payload(vm_id, "logs") signed_operation = self.sign_payload(payload, self.ephemeral_key) - - ws_url, header = await self._generate_header(vm_id=vm_id, operation="logs") + path = payload["path"] + ws_url = f"{self.node_url}{path}" async with self.session.ws_connect(ws_url) as ws: auth_message = { From fa998aea985bf23a99acdaf6bd5ec57b602cbcc9 Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 19 Jun 2024 20:02:42 +0200 Subject: [PATCH 11/34] Fix: black issue --- src/aleph/sdk/client/vmclient.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 226c443c..74434790 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -66,9 +66,7 @@ async def _generate_pubkey_signature_header(self) -> str: ) def create_payload(self, vm_id: str, operation: str) -> Dict[str, str]: - path = ( - f"/control/machine/{vm_id}/{operation}" - ) + path = f"/control/machine/{vm_id}/{operation}" payload = { "time": datetime.datetime.utcnow().isoformat() + "Z", "method": "POST", From 49c81b5cbae7c5031e7101cc4f8dbac4225be64d Mon Sep 17 00:00:00 2001 From: 1yam Date: Thu, 20 Jun 2024 11:23:06 +0200 Subject: [PATCH 12/34] Fix: test fix workflow --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 81962db8..125bcec1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,6 @@ dependencies = [ "pytest-cov==4.1.0", "pytest-mock==3.12.0", "pytest-asyncio==0.23.5", - "jwcrypto==1.5.6", "aioresponses==0.7.6", "fastapi", "httpx", From 149778b949032cc5094e994753481ed37c4a8ccc Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 18:52:50 +0200 Subject: [PATCH 13/34] feat(vm_client): add missing types annotations --- src/aleph/sdk/client/vmclient.py | 28 ++++++++++++++------------- tests/unit/aleph_vm_authentication.py | 8 ++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 74434790..353c3bc4 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -1,7 +1,7 @@ import datetime import json import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, List, AsyncGenerator from urllib.parse import urlparse import aiohttp @@ -9,6 +9,8 @@ from jwcrypto import jwk from jwcrypto.jwa import JWA +from aleph_message.models import ItemHash + from aleph.sdk.types import Account from aleph.sdk.utils import to_0x_hex @@ -65,7 +67,7 @@ async def _generate_pubkey_signature_header(self) -> str: } ) - def create_payload(self, vm_id: str, operation: str) -> Dict[str, str]: + def create_payload(self, vm_id: ItemHash, operation: str) -> Dict[str, str]: path = f"/control/machine/{vm_id}/{operation}" payload = { "time": datetime.datetime.utcnow().isoformat() + "Z", @@ -88,7 +90,7 @@ def sign_payload(self, payload: Dict[str, str], ephemeral_key) -> str: return signed_operation async def _generate_header( - self, vm_id: str, operation: str + self, vm_id: ItemHash, operation: str ) -> Tuple[str, Dict[str, str]]: payload = self.create_payload(vm_id, operation) signed_operation = self.sign_payload(payload, self.ephemeral_key) @@ -106,7 +108,7 @@ async def _generate_header( path = payload["path"] return f"{self.node_url}{path}", headers - async def perform_operation(self, vm_id, operation): + async def perform_operation(self, vm_id: ItemHash, operation: str) -> Tuple[Optional[int], str]: if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() @@ -118,11 +120,12 @@ async def perform_operation(self, vm_id, operation): async with self.session.post(url, headers=header) as response: response_text = await response.text() return response.status, response_text + except aiohttp.ClientError as e: logger.error(f"HTTP error during operation {operation}: {str(e)}") return None, str(e) - async def get_logs(self, vm_id): + async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() @@ -147,23 +150,22 @@ async def get_logs(self, vm_id): elif msg.type == aiohttp.WSMsgType.ERROR: break - async def start_instance(self, vm_id): + async def start_instance(self, vm_id: ItemHash) -> Tuple[int, str]: return await self.notify_allocation(vm_id) - async def stop_instance(self, vm_id): + async def stop_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: return await self.perform_operation(vm_id, "stop") - async def reboot_instance(self, vm_id): - + async def reboot_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: return await self.perform_operation(vm_id, "reboot") - async def erase_instance(self, vm_id): + async def erase_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: return await self.perform_operation(vm_id, "erase") - async def expire_instance(self, vm_id): + async def expire_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: return await self.perform_operation(vm_id, "expire") - async def notify_allocation(self, vm_id) -> Tuple[Any, str]: + async def notify_allocation(self, vm_id: ItemHash) -> Tuple[int, str]: json_data = {"instance": vm_id} async with self.session.post( f"{self.node_url}/control/allocation/notify", json=json_data @@ -171,7 +173,7 @@ async def notify_allocation(self, vm_id) -> Tuple[Any, str]: form_response_text = await s.text() return s.status, form_response_text - async def manage_instance(self, vm_id, operations): + async def manage_instance(self, vm_id: ItemHash, operations: List[str]): for operation in operations: status, response = await self.perform_operation(vm_id, operation) if status != 200: diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 2ca144a8..e695c002 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -20,7 +20,7 @@ DOMAIN_NAME = "localhost" -def is_token_still_valid(datestr: str): +def is_token_still_valid(datestr: str) -> bool: """ Checks if a token has expired based on its expiry timestamp """ @@ -30,7 +30,7 @@ def is_token_still_valid(datestr: str): return expiry_datetime > current_datetime -def verify_wallet_signature(signature, message, address): +def verify_wallet_signature(signature: bytes, message: str, address: str) -> bool: """ Verifies a signature issued by a wallet """ @@ -71,7 +71,7 @@ def payload_must_be_hex(cls, v: bytes) -> bytes: return bytes.fromhex(v.decode()) @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values) -> Dict[str, bytes]: + def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) @@ -81,7 +81,7 @@ def check_expiry(cls, values) -> Dict[str, bytes]: return values @root_validator(pre=False, skip_on_failure=True) - def check_signature(cls, values) -> Dict[str, bytes]: + def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: """Check that the signature is valid""" signature: bytes = values["signature"] payload: bytes = values["payload"] From 017bf011a2a0f4d148d344368ddc1cea5a06fa76 Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 18:53:47 +0200 Subject: [PATCH 14/34] refactor(vm_client): remove duplicated types annotations --- src/aleph/sdk/client/vmclient.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 353c3bc4..3b4871a9 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -31,11 +31,11 @@ def __init__( node_url: str = "", session: Optional[aiohttp.ClientSession] = None, ): - self.account: Account = account - self.ephemeral_key: jwk.JWK = jwk.JWK.generate(kty="EC", crv="P-256") - self.node_url: str = node_url + self.account = account + self.ephemeral_key = jwk.JWK.generate(kty="EC", crv="P-256") + self.node_url = node_url self.pubkey_payload = self._generate_pubkey_payload() - self.pubkey_signature_header: str = "" + self.pubkey_signature_header = "" self.session = session or aiohttp.ClientSession() def _generate_pubkey_payload(self) -> Dict[str, Any]: From 7ec6c421fdbf917b505ded33f0c5b6fb9a94dc03 Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 18:54:56 +0200 Subject: [PATCH 15/34] refactor(vm_client): avoid using single letter variable names --- src/aleph/sdk/client/vmclient.py | 7 +++--- tests/unit/aleph_vm_authentication.py | 35 ++++++++++++++++----------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 3b4871a9..ff5cd04f 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -169,9 +169,10 @@ async def notify_allocation(self, vm_id: ItemHash) -> Tuple[int, str]: json_data = {"instance": vm_id} async with self.session.post( f"{self.node_url}/control/allocation/notify", json=json_data - ) as s: - form_response_text = await s.text() - return s.status, form_response_text + ) as session: + form_response_text = await session.text() + + return session.status, form_response_text async def manage_instance(self, vm_id: ItemHash, operations: List[str]): for operation in operations: diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index e695c002..4e3d7380 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -61,15 +61,16 @@ class SignedPubKeyHeader(BaseModel): payload: bytes @validator("signature") - def signature_must_be_hex(cls, v: bytes) -> bytes: + def signature_must_be_hex(cls, value: bytes) -> bytes: """Convert the signature from hexadecimal to bytes""" - return bytes.fromhex(v.decode()) + + return bytes.fromhex(value.decode()) @validator("payload") - def payload_must_be_hex(cls, v: bytes) -> bytes: + def payload_must_be_hex(cls, value: bytes) -> bytes: """Convert the payload from hexadecimal to bytes""" - return bytes.fromhex(v.decode()) + return bytes.fromhex(value.decode()) @root_validator(pre=False, skip_on_failure=True) def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: """Check that the token has not expired""" @@ -104,7 +105,7 @@ class SignedOperationPayload(BaseModel): # body_sha256: str # disabled since there is no body @validator("time") - def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: + def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: """Check that the time is current and the payload is not a replay attack.""" max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( minutes=2 @@ -112,11 +113,14 @@ def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: max_future = datetime.datetime.now( tz=datetime.timezone.utc ) + datetime.timedelta(minutes=2) - if v < max_past: + + if value < max_past: raise ValueError("Time is too far in the past") - if v > max_future: + + if value > max_future: raise ValueError("Time is too far in the future") - return v + + return value class SignedOperation(BaseModel): @@ -126,19 +130,22 @@ class SignedOperation(BaseModel): payload: bytes @validator("signature") - def signature_must_be_hex(cls, v) -> bytes: + def signature_must_be_hex(cls, value: str) -> bytes: """Convert the signature from hexadecimal to bytes""" + try: - return bytes.fromhex(v.removeprefix(b"0x").decode()) + return bytes.fromhex(value.removeprefix(b"0x").decode()) + except pydantic.ValidationError as error: - print(v) - logger.warning(v) + print(value) + logger.warning(value) raise error @validator("payload") - def payload_must_be_hex(cls, v) -> bytes: + def payload_must_be_hex(cls, value: bytes) -> bytes: """Convert the payload from hexadecimal to bytes""" - v = bytes.fromhex(v.decode()) + + v = bytes.fromhex(value.decode()) _ = SignedOperationPayload.parse_raw(v) return v From 93dbb223dce1e11b864d20c53e686cdc8dc10351 Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 18:55:24 +0200 Subject: [PATCH 16/34] feat(vm_client): increase test_notify_allocation precision --- tests/unit/test_vmclient.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index 8733954d..d8f4921e 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -1,5 +1,8 @@ import aiohttp import pytest + +from yarl import URL + from aiohttp import web from aioresponses import aioresponses from aleph_message.models import ItemHash @@ -21,7 +24,8 @@ async def test_notify_allocation(): ) m.post("http://localhost/control/allocation/notify", status=200) await vm_client.notify_allocation(vm_id=vm_id) - assert m.requests + assert len(m.requests) == 1 + assert ('POST', URL('http://localhost/control/allocation/notify')) in m.requests await vm_client.session.close() From f93f202c7cb06e54972b80a3811379adf82045d5 Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 18:56:13 +0200 Subject: [PATCH 17/34] refactor(vm_client): add empty lines for code readability --- src/aleph/sdk/client/vmclient.py | 2 ++ src/aleph/sdk/types.py | 1 + tests/unit/aleph_vm_authentication.py | 31 +++++++++++++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index ff5cd04f..6c44973b 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -144,6 +144,7 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: } } await ws.send_json(auth_message) + async for msg in ws: # msg is of type aiohttp.WSMessage if msg.type == aiohttp.WSMsgType.TEXT: yield msg.data @@ -167,6 +168,7 @@ async def expire_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: async def notify_allocation(self, vm_id: ItemHash) -> Tuple[int, str]: json_data = {"instance": vm_id} + async with self.session.post( f"{self.node_url}/control/allocation/notify", json=json_data ) as session: diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index d344123d..71bc2b53 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -22,6 +22,7 @@ async def sign_message(self, message: Dict) -> Dict: ... @abstractmethod async def sign_raw(self, buffer: bytes) -> bytes: ... + @abstractmethod def get_address(self) -> str: ... diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 4e3d7380..13d5cb89 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -36,6 +36,7 @@ def verify_wallet_signature(signature: bytes, message: str, address: str) -> boo """ enc_msg = encode_defunct(hexstr=message) computed_address = Account.recover_message(enc_msg, signature=signature) + return computed_address.lower() == address.lower() @@ -53,6 +54,7 @@ class SignedPubKeyPayload(BaseModel): @property def json_web_key(self) -> jwk.JWK: """Return the ephemeral public key as Json Web Key""" + return jwk.JWK(**self.pubkey) @@ -71,14 +73,17 @@ def payload_must_be_hex(cls, value: bytes) -> bytes: """Convert the payload from hexadecimal to bytes""" return bytes.fromhex(value.decode()) + @root_validator(pre=False, skip_on_failure=True) def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) + if not is_token_still_valid(content.expires): msg = "Token expired" raise ValueError(msg) + return values @root_validator(pre=False, skip_on_failure=True) @@ -87,14 +92,17 @@ def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: signature: bytes = values["signature"] payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) + if not verify_wallet_signature(signature, payload.hex(), content.address): msg = "Invalid signature" raise ValueError(msg) + return values @property def content(self) -> SignedPubKeyPayload: """Return the content of the header""" + return SignedPubKeyPayload.parse_raw(self.payload) @@ -139,6 +147,7 @@ def signature_must_be_hex(cls, value: str) -> bytes: except pydantic.ValidationError as error: print(value) logger.warning(value) + raise error @validator("payload") @@ -147,6 +156,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes: v = bytes.fromhex(value.decode()) _ = SignedOperationPayload.parse_raw(v) + return v @property @@ -158,27 +168,35 @@ def content(self) -> SignedOperationPayload: def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: """Get the ephemeral public key that is signed by the wallet from the request headers.""" signed_pubkey_header = request.headers.get("X-SignedPubKey") + if not signed_pubkey_header: raise web.HTTPBadRequest(reason="Missing X-SignedPubKey header") try: return SignedPubKeyHeader.parse_raw(signed_pubkey_header) + except KeyError as error: logger.debug(f"Missing X-SignedPubKey header: {error}") raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey fields") from error + except json.JSONDecodeError as error: raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from error + except ValueError as errors: logging.debug(errors) + for err in errors.args[0]: if isinstance(err.exc, json.JSONDecodeError): raise web.HTTPBadRequest( reason="Invalid X-SignedPubKey format" ) from errors + if str(err.exc) == "Token expired": raise web.HTTPUnauthorized(reason="Token expired") from errors + if str(err.exc) == "Invalid signature": raise web.HTTPUnauthorized(reason="Invalid signature") from errors + else: raise errors @@ -188,10 +206,13 @@ def get_signed_operation(request: web.Request) -> SignedOperation: try: signed_operation = request.headers["X-SignedOperation"] return SignedOperation.parse_raw(signed_operation) + except KeyError as error: raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error + except json.JSONDecodeError as error: raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error + except ValidationError as error: logger.debug(f"Invalid X-SignedOperation fields: {error}") raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error @@ -208,9 +229,12 @@ def verify_signed_operation( pubkey, signed_operation.payload, signed_operation.signature ) logger.debug("Signature verified") + return signed_pubkey.content.address + except cryptography.exceptions.InvalidSignature as e: logger.debug("Failing to validate signature for operation", e) + raise web.HTTPUnauthorized(reason="Signature could not verified") @@ -218,21 +242,25 @@ async def authenticate_jwk(request: web.Request) -> str: """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) + if signed_pubkey.content.domain != DOMAIN_NAME: logger.debug( f"Invalid domain '{signed_pubkey.content.domain}' != '{DOMAIN_NAME}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") + if signed_operation.content.path != request.path: logger.debug( f"Invalid path '{signed_operation.content.path}' != '{request.path}'" ) raise web.HTTPUnauthorized(reason="Invalid path") + if signed_operation.content.method != request.method: logger.debug( f"Invalid method '{signed_operation.content.method}' != '{request.method}'" ) raise web.HTTPUnauthorized(reason="Invalid method") + return verify_signed_operation(signed_operation, signed_pubkey) @@ -240,17 +268,20 @@ async def authenticate_websocket_message(message) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) + if signed_pubkey.content.node_url != DOMAIN_NAME: logger.debug( f"Invalid domain '{signed_pubkey.content.node_url}' != '{DOMAIN_NAME}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") + return verify_signed_operation(signed_operation, signed_pubkey) def require_jwk_authentication( handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: + @functools.wraps(handler) async def wrapper(request): try: From 3abcf36bfe0732110ee2a489f0337e98a7979378 Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 19:30:56 +0200 Subject: [PATCH 18/34] style: run linting:fmt --- src/aleph/sdk/client/vmclient.py | 9 +++++---- tests/unit/test_vmclient.py | 6 ++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 6c44973b..8c61ab48 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -1,16 +1,15 @@ import datetime import json import logging -from typing import Any, Dict, Optional, Tuple, List, AsyncGenerator +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import aiohttp +from aleph_message.models import ItemHash from eth_account.messages import encode_defunct from jwcrypto import jwk from jwcrypto.jwa import JWA -from aleph_message.models import ItemHash - from aleph.sdk.types import Account from aleph.sdk.utils import to_0x_hex @@ -108,7 +107,9 @@ async def _generate_header( path = payload["path"] return f"{self.node_url}{path}", headers - async def perform_operation(self, vm_id: ItemHash, operation: str) -> Tuple[Optional[int], str]: + async def perform_operation( + self, vm_id: ItemHash, operation: str + ) -> Tuple[Optional[int], str]: if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index d8f4921e..d62fe611 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -1,11 +1,9 @@ import aiohttp import pytest - -from yarl import URL - from aiohttp import web from aioresponses import aioresponses from aleph_message.models import ItemHash +from yarl import URL from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.client.vmclient import VmClient @@ -25,7 +23,7 @@ async def test_notify_allocation(): m.post("http://localhost/control/allocation/notify", status=200) await vm_client.notify_allocation(vm_id=vm_id) assert len(m.requests) == 1 - assert ('POST', URL('http://localhost/control/allocation/notify')) in m.requests + assert ("POST", URL("http://localhost/control/allocation/notify")) in m.requests await vm_client.session.close() From ca16c5ab65bf4ef5bb952fca2efaae5a37bcb6e4 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Fri, 21 Jun 2024 12:40:19 +0200 Subject: [PATCH 19/34] Fix: Required an old version of `aleph-message` --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 125bcec1..4330ae44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "aiohttp>=3.8.3", - "aleph-message~=0.4.4", + "aleph-message>=0.4.7", "coincurve; python_version<\"3.11\"", "coincurve>=19.0.0; python_version>=\"3.11\"", "eth_abi>=4.0.0; python_version>=\"3.11\"", From 5162096f43ffd3d06d845141ab78d14be4cc3d69 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Fri, 21 Jun 2024 13:13:03 +0200 Subject: [PATCH 20/34] Fix: Newer aleph-message requires InstanceEnvironment Else tests were breaking. --- src/aleph/sdk/client/authenticated_http.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 60d42b2b..bf34e225 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -30,6 +30,7 @@ from aleph_message.models.execution.environment import ( FunctionEnvironment, HypervisorType, + InstanceEnvironment, MachineResources, ) from aleph_message.models.execution.instance import RootfsVolume @@ -539,8 +540,7 @@ async def create_instance( content = InstanceContent( address=address, allow_amend=allow_amend, - environment=FunctionEnvironment( - reproducible=False, + environment=InstanceEnvironment( internet=internet, aleph_api=aleph_api, hypervisor=hypervisor, From 65a0dfe53918232886a7d7af2e970fc565c2b632 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Fri, 21 Jun 2024 13:13:20 +0200 Subject: [PATCH 21/34] Fix: Qemu was not the default hypervisor for instances. --- src/aleph/sdk/client/authenticated_http.py | 6 ++++-- tests/unit/test_asynchronous.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index bf34e225..6d44b526 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -535,7 +535,9 @@ async def create_instance( timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold) - hypervisor = hypervisor or HypervisorType.firecracker + + # Default to the QEMU hypervisor for instances. + selected_hypervisor: HypervisorType = hypervisor or HypervisorType.qemu content = InstanceContent( address=address, @@ -543,7 +545,7 @@ async def create_instance( environment=InstanceEnvironment( internet=internet, aleph_api=aleph_api, - hypervisor=hypervisor, + hypervisor=selected_hypervisor, ), variables=environment_variables, resources=MachineResources( diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index 0fa0df38..0f909408 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -157,7 +157,7 @@ async def test_create_instance_no_hypervisor(mock_session_with_post_success): hypervisor=None, ) - assert instance_message.content.environment.hypervisor == HypervisorType.firecracker + assert instance_message.content.environment.hypervisor == HypervisorType.qemu assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(instance_message, InstanceMessage) From 225b42a01d0ffddefa1b0a7dfb3da9a7228cfafc Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Fri, 21 Jun 2024 14:42:38 +0200 Subject: [PATCH 22/34] Fix: Pythom 3.12 fails setup libsecp256k1 When "using bundled libsecp256k1", the setup using `/tmp/venv/bin/hatch run testing:test` fails to proceed on Python 3.12. That library `secp256k1` has been unmaintained for more than 2 years now (0.14.0, Nov 6, 2021), and seems to not support Python 3.12. The error in the logs: ``` File "/tmp/pip-build-env-ye8d6ort/overlay/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 862, in get_command_obj cmd_obj = self.command_obj[command] = klass(self) ^^^^^^^^^^^ TypeError: 'NoneType' object is not callable [end of output] ``` See failing CI run: https://github.com/aleph-im/aleph-sdk-python/actions/runs/9613634583/job/26516767722 --- .github/workflows/pytest.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index efe667ea..8d5456c6 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -14,8 +14,11 @@ on: jobs: build: strategy: + fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11" ] + # An issue with secp256k1 prevents Python 3.12 from working + # See https://github.com/baking-bad/pytezos/issues/370 runs-on: ubuntu-latest steps: From 93bffa98307c107c6b850a5cf6fa2149ec459112 Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 25 Jun 2024 17:17:34 +0200 Subject: [PATCH 23/34] doc(README): command to launch tests was incorrect --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cfc7e1a4..3d2aea9c 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ $ pip install -e .[all] You can use the test env defined for hatch to run the tests: ```shell -$ hatch run test:run +$ hatch run testing:run ``` See `hatch env show` for more information about all the environments and their scripts. From b53505df94b9565b02e38b875526e8db5df882c0 Mon Sep 17 00:00:00 2001 From: 1yam Date: Fri, 28 Jun 2024 15:38:36 +0200 Subject: [PATCH 24/34] Refactor: create and sign playload goes to utils and some fix --- pyproject.toml | 1 + src/aleph/sdk/client/vmclient.py | 43 ++++++++++---------------------- src/aleph/sdk/utils.py | 28 ++++++++++++++++++++- tests/unit/test_vmclient.py | 17 +++++++++++-- 4 files changed, 56 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4330ae44..8874c3f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ dependencies = [ "pytest-cov==4.1.0", "pytest-mock==3.12.0", "pytest-asyncio==0.23.5", + "pytest-aiohttp==1.0.5", "aioresponses==0.7.6", "fastapi", "httpx", diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 8c61ab48..a2dd8473 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -8,10 +8,13 @@ from aleph_message.models import ItemHash from eth_account.messages import encode_defunct from jwcrypto import jwk -from jwcrypto.jwa import JWA from aleph.sdk.types import Account -from aleph.sdk.utils import to_0x_hex +from aleph.sdk.utils import ( + create_vm_control_payload, + sign_vm_control_payload, + to_0x_hex, +) logger = logging.getLogger(__name__) @@ -66,33 +69,11 @@ async def _generate_pubkey_signature_header(self) -> str: } ) - def create_payload(self, vm_id: ItemHash, operation: str) -> Dict[str, str]: - path = f"/control/machine/{vm_id}/{operation}" - payload = { - "time": datetime.datetime.utcnow().isoformat() + "Z", - "method": "POST", - "path": path, - } - return payload - - def sign_payload(self, payload: Dict[str, str], ephemeral_key) -> str: - payload_as_bytes = json.dumps(payload).encode("utf-8") - payload_signature = JWA.signing_alg("ES256").sign( - ephemeral_key, payload_as_bytes - ) - signed_operation = json.dumps( - { - "payload": payload_as_bytes.hex(), - "signature": payload_signature.hex(), - } - ) - return signed_operation - async def _generate_header( self, vm_id: ItemHash, operation: str ) -> Tuple[str, Dict[str, str]]: - payload = self.create_payload(vm_id, operation) - signed_operation = self.sign_payload(payload, self.ephemeral_key) + payload = create_vm_control_payload(vm_id, operation) + signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) if not self.pubkey_signature_header: self.pubkey_signature_header = ( @@ -132,8 +113,8 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: await self._generate_pubkey_signature_header() ) - payload = self.create_payload(vm_id, "logs") - signed_operation = self.sign_payload(payload, self.ephemeral_key) + payload = create_vm_control_payload(vm_id, "logs") + signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) path = payload["path"] ws_url = f"{self.node_url}{path}" @@ -177,12 +158,14 @@ async def notify_allocation(self, vm_id: ItemHash) -> Tuple[int, str]: return session.status, form_response_text - async def manage_instance(self, vm_id: ItemHash, operations: List[str]): + async def manage_instance( + self, vm_id: ItemHash, operations: List[str] + ) -> Tuple[int, str]: for operation in operations: status, response = await self.perform_operation(vm_id, operation) if status != 200: return status, response - return + return 200, "All operations completed successfully" async def close(self): await self.session.close() diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index ffd7ac6b..2a7ce1fb 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,5 +1,7 @@ +import datetime import errno import hashlib +import json import logging import os from datetime import date, datetime, time @@ -8,6 +10,7 @@ from shutil import make_archive from typing import ( Any, + Dict, Iterable, Mapping, Optional, @@ -20,9 +23,10 @@ ) from zipfile import BadZipFile, ZipFile -from aleph_message.models import MessageType +from aleph_message.models import ItemHash, MessageType from aleph_message.models.execution.program import Encoding from aleph_message.models.execution.volume import MachineVolume +from jwcrypto.jwa import JWA from pydantic.json import pydantic_encoder from aleph.sdk.conf import settings @@ -195,3 +199,25 @@ def bytes_from_hex(hex_string: str) -> bytes: hex_string = hex_string[2:] hex_string = bytes.fromhex(hex_string) return hex_string + + +def create_vm_control_payload(vm_id: ItemHash, operation: str) -> Dict[str, str]: + path = f"/control/machine/{vm_id}/{operation}" + payload = { + "time": datetime.utcnow().isoformat() + "Z", + "method": "POST", + "path": path, + } + return payload + + +def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: + payload_as_bytes = json.dumps(payload).encode("utf-8") + payload_signature = JWA.signing_alg("ES256").sign(ephemeral_key, payload_as_bytes) + signed_operation = json.dumps( + { + "payload": payload_as_bytes.hex(), + "signature": payload_signature.hex(), + } + ) + return signed_operation diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index d62fe611..212c89c9 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -161,16 +161,29 @@ async def websocket_handler(request): return ws app = web.Application() - app.router.add_route("GET", "/logs/{vm_id}", websocket_handler) + app.router.add_route( + "GET", "/control/machine/{vm_id}/logs", websocket_handler + ) # Update route to match the URL client = await aiohttp_client(app) + node_url = str(client.make_url("")).rstrip("/") + vm_client = VmClient( account=account, - node_url=str(client.make_url("/")).rstrip("/"), + node_url=node_url, session=client.session, ) + original_get_logs = vm_client.get_logs + + async def debug_get_logs(vm_id): + url = f"{vm_client.node_url}/control/machine/{vm_id}/logs" + async for log in original_get_logs(vm_id): + yield log + + vm_client.get_logs = debug_get_logs + logs = [] async for log in vm_client.get_logs(vm_id): logs.append(log) From 3c66af0a9589b09e11cbe525b9ddb3bb06e7708f Mon Sep 17 00:00:00 2001 From: 1yam Date: Fri, 28 Jun 2024 15:55:54 +0200 Subject: [PATCH 25/34] Fix: linting issue --- src/aleph/sdk/client/vmclient.py | 2 +- src/aleph/sdk/utils.py | 1 - tests/unit/test_vmclient.py | 9 --------- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index a2dd8473..212b7eb5 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -163,7 +163,7 @@ async def manage_instance( ) -> Tuple[int, str]: for operation in operations: status, response = await self.perform_operation(vm_id, operation) - if status != 200: + if status != 200 and status: return status, response return 200, "All operations completed successfully" diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 2a7ce1fb..2d1b30c7 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,4 +1,3 @@ -import datetime import errno import hashlib import json diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index 212c89c9..87809221 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -175,15 +175,6 @@ async def websocket_handler(request): session=client.session, ) - original_get_logs = vm_client.get_logs - - async def debug_get_logs(vm_id): - url = f"{vm_client.node_url}/control/machine/{vm_id}/logs" - async for log in original_get_logs(vm_id): - yield log - - vm_client.get_logs = debug_get_logs - logs = [] async for log in vm_client.get_logs(vm_id): logs.append(log) From 0c62cd550b7980be95bfc6e7883ea825df0ff967 Mon Sep 17 00:00:00 2001 From: 1yam Date: Fri, 28 Jun 2024 16:04:20 +0200 Subject: [PATCH 26/34] Fix: mypy issue --- tests/unit/aleph_vm_authentication.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 13d5cb89..7f6704c1 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -137,12 +137,13 @@ class SignedOperation(BaseModel): signature: bytes payload: bytes + @validator("signature") def signature_must_be_hex(cls, value: str) -> bytes: """Convert the signature from hexadecimal to bytes""" try: - return bytes.fromhex(value.removeprefix(b"0x").decode()) + return bytes.fromhex(value[2:] if value.startswith("0x") else value) except pydantic.ValidationError as error: print(value) @@ -269,9 +270,9 @@ async def authenticate_websocket_message(message) -> str: signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) - if signed_pubkey.content.node_url != DOMAIN_NAME: + if signed_pubkey.content.domain != DOMAIN_NAME: logger.debug( - f"Invalid domain '{signed_pubkey.content.node_url}' != '{DOMAIN_NAME}'" + f"Invalid domain '{signed_pubkey.content.domain}' != '{DOMAIN_NAME}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") From 0fab7c3b46043af0eb6a4a38240ae1ab493440b6 Mon Sep 17 00:00:00 2001 From: 1yam Date: Fri, 28 Jun 2024 16:06:00 +0200 Subject: [PATCH 27/34] fix: black --- tests/unit/aleph_vm_authentication.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 7f6704c1..69b8a2d3 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -137,7 +137,6 @@ class SignedOperation(BaseModel): signature: bytes payload: bytes - @validator("signature") def signature_must_be_hex(cls, value: str) -> bytes: """Convert the signature from hexadecimal to bytes""" From 5380da5cb51764cef524bc450037aad343351e9e Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 2 Jul 2024 15:38:56 +0200 Subject: [PATCH 28/34] feat: use bytes_from_hex where it makes sens --- tests/unit/aleph_vm_authentication.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 69b8a2d3..a29b7d85 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -15,6 +15,9 @@ from jwcrypto.jwa import JWA from pydantic import BaseModel, ValidationError, root_validator, validator +from aleph.sdk.utils import bytes_from_hex + + logger = logging.getLogger(__name__) DOMAIN_NAME = "localhost" @@ -66,13 +69,13 @@ class SignedPubKeyHeader(BaseModel): def signature_must_be_hex(cls, value: bytes) -> bytes: """Convert the signature from hexadecimal to bytes""" - return bytes.fromhex(value.decode()) + return bytes_from_hex(value.decode()) @validator("payload") def payload_must_be_hex(cls, value: bytes) -> bytes: """Convert the payload from hexadecimal to bytes""" - return bytes.fromhex(value.decode()) + return bytes_from_hex(value.decode()) @root_validator(pre=False, skip_on_failure=True) def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: @@ -142,7 +145,9 @@ def signature_must_be_hex(cls, value: str) -> bytes: """Convert the signature from hexadecimal to bytes""" try: - return bytes.fromhex(value[2:] if value.startswith("0x") else value) + if isinstance(value, bytes): + value = value.decode() + return bytes_from_hex(value) except pydantic.ValidationError as error: print(value) @@ -154,7 +159,7 @@ def signature_must_be_hex(cls, value: str) -> bytes: def payload_must_be_hex(cls, value: bytes) -> bytes: """Convert the payload from hexadecimal to bytes""" - v = bytes.fromhex(value.decode()) + v = bytes_from_hex(value.decode()) _ = SignedOperationPayload.parse_raw(v) return v From fc1e6af8b755a146850da90ce9bd15e533b61f1a Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 2 Jul 2024 16:34:01 +0200 Subject: [PATCH 29/34] chore: use ruff new CLI api --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8874c3f0..b52efe66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,13 +153,13 @@ dependencies = [ [tool.hatch.envs.linting.scripts] typing = "mypy --config-file=pyproject.toml {args:} ./src/ ./tests/ ./examples/" style = [ - "ruff {args:.} ./src/ ./tests/ ./examples/", + "ruff check {args:.} ./src/ ./tests/ ./examples/", "black --check --diff {args:} ./src/ ./tests/ ./examples/", "isort --check-only --profile black {args:} ./src/ ./tests/ ./examples/", ] fmt = [ "black {args:} ./src/ ./tests/ ./examples/", - "ruff --fix {args:.} ./src/ ./tests/ ./examples/", + "ruff check --fix {args:.} ./src/ ./tests/ ./examples/", "isort --profile black {args:} ./src/ ./tests/ ./examples/", "style", ] From 2f180e3a69c80b67a380322500c1d23270cccace Mon Sep 17 00:00:00 2001 From: Laurent Peuch Date: Tue, 2 Jul 2024 16:34:31 +0200 Subject: [PATCH 30/34] feat: add unit tests for authentication mechanisms of VmClient --- tests/unit/aleph_vm_authentication.py | 15 ++-- tests/unit/test_vmclient.py | 113 ++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 7 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index a29b7d85..69451e93 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -17,7 +17,6 @@ from aleph.sdk.utils import bytes_from_hex - logger = logging.getLogger(__name__) DOMAIN_NAME = "localhost" @@ -243,14 +242,14 @@ def verify_signed_operation( raise web.HTTPUnauthorized(reason="Signature could not verified") -async def authenticate_jwk(request: web.Request) -> str: +async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) -> str: """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) - if signed_pubkey.content.domain != DOMAIN_NAME: + if signed_pubkey.content.domain != domain_name: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{DOMAIN_NAME}'" + f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") @@ -269,14 +268,16 @@ async def authenticate_jwk(request: web.Request) -> str: return verify_signed_operation(signed_operation, signed_pubkey) -async def authenticate_websocket_message(message) -> str: +async def authenticate_websocket_message( + message, domain_name: str = DOMAIN_NAME +) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) - if signed_pubkey.content.domain != DOMAIN_NAME: + if signed_pubkey.content.domain != domain_name: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{DOMAIN_NAME}'" + f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index 87809221..d0198c36 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -1,3 +1,6 @@ +import json +from urllib.parse import urlparse + import aiohttp import pytest from aiohttp import web @@ -8,6 +11,14 @@ from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.client.vmclient import VmClient +from .aleph_vm_authentication import ( + SignedOperation, + SignedPubKeyHeader, + authenticate_jwk, + authenticate_websocket_message, + verify_signed_operation, +) + @pytest.mark.asyncio async def test_notify_allocation(): @@ -183,3 +194,105 @@ async def websocket_handler(request): assert logs == ["mock_log_entry"] await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_authenticate_jwk(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def test_authenticate_route(request): + address = await authenticate_jwk(request, domain_name=urlparse(node_url).netloc) + assert vm_client.account.get_address() == address + return web.Response(text="ok") + + app = web.Application() + app.router.add_route( + "POST", f"/control/machine/{vm_id}/stop", test_authenticate_route + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + status_code, response_text = await vm_client.stop_instance(vm_id) + assert status_code == 200 + assert response_text == "ok" + + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_websocket_authentication(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + first_message = await ws.receive_json() + credentials = first_message["auth"] + address = await authenticate_websocket_message( + { + "X-SignedPubKey": json.loads(credentials["X-SignedPubKey"]), + "X-SignedOperation": json.loads(credentials["X-SignedOperation"]), + }, + domain_name=urlparse(node_url).netloc, + ) + + assert vm_client.account.get_address() == address + await ws.send_str(address) + + return ws + + app = web.Application() + app.router.add_route( + "GET", "/control/machine/{vm_id}/logs", websocket_handler + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + valid = False + async for address in vm_client.get_logs(vm_id): + assert address == vm_client.account.get_address() + valid = True + + # this is done to ensure that the ws as runned at least once and avoid + # having silent errors + assert valid + + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_vm_client_generate_correct_authentication_headers(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + + path, headers = await vm_client._generate_header(vm_id, "reboot") + signed_pubkey = SignedPubKeyHeader.parse_raw(headers["X-SignedPubKey"]) + signed_operation = SignedOperation.parse_raw(headers["X-SignedOperation"]) + address = verify_signed_operation(signed_operation, signed_pubkey) + + assert vm_client.account.get_address() == address From d271038e1a1bca2df3fc2bf701d2d247d2b48e2f Mon Sep 17 00:00:00 2001 From: 1yam Date: Wed, 3 Jul 2024 16:48:28 +0200 Subject: [PATCH 31/34] fix: debug code remove --- tests/unit/aleph_vm_authentication.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 69451e93..7d213547 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -147,11 +147,8 @@ def signature_must_be_hex(cls, value: str) -> bytes: if isinstance(value, bytes): value = value.decode() return bytes_from_hex(value) - except pydantic.ValidationError as error: - print(value) logger.warning(value) - raise error @validator("payload") From 5e881613cb5a4be4222a7b520e99c52357fb48d0 Mon Sep 17 00:00:00 2001 From: 1yam <40899431+1yam@users.noreply.github.com> Date: Thu, 4 Jul 2024 11:00:59 +0200 Subject: [PATCH 32/34] Update vmclient.py Co-authored-by: Olivier Le Thanh Duong --- src/aleph/sdk/client/vmclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vmclient.py index 212b7eb5..52e9d152 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vmclient.py @@ -113,7 +113,7 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: await self._generate_pubkey_signature_header() ) - payload = create_vm_control_payload(vm_id, "logs") + payload = create_vm_control_payload(vm_id, "stream_logs") signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) path = payload["path"] ws_url = f"{self.node_url}{path}" From d9b18920405e31638da52de0cdcb51a89fbdd24b Mon Sep 17 00:00:00 2001 From: 1yam Date: Thu, 4 Jul 2024 11:07:07 +0200 Subject: [PATCH 33/34] Fix: update unit test to use stream_logs endpoint instead of logs --- tests/unit/test_vmclient.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vmclient.py index d0198c36..bc201472 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vmclient.py @@ -173,7 +173,7 @@ async def websocket_handler(request): app = web.Application() app.router.add_route( - "GET", "/control/machine/{vm_id}/logs", websocket_handler + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler ) # Update route to match the URL client = await aiohttp_client(app) @@ -254,7 +254,7 @@ async def websocket_handler(request): app = web.Application() app.router.add_route( - "GET", "/control/machine/{vm_id}/logs", websocket_handler + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler ) # Update route to match the URL client = await aiohttp_client(app) From 247dbfcb01c0d2ab8bbe9c4b11854fd94044fff6 Mon Sep 17 00:00:00 2001 From: nesitor Date: Thu, 4 Jul 2024 22:32:19 +0200 Subject: [PATCH 34/34] Implement `VmConfidentialClient` class (#138) * Problem: A user cannot initialize an already created confidential VM. Solution: Implement `VmConfidentialClient` class to be able to initialize and interact with confidential VMs. * Problem: Auth was not working Corrections: * Measurement type returned was missing field needed for validation of measurements * Port number was not handled correctly in authentifaction * Adapt to new auth protocol where domain is moved to the operation field (While keeping compat with the old format) * Get measurement was not working since signed with the wrong method * inject_secret was not sending a json * Websocked auth was sending a twice serialized json * update 'vendorized' aleph-vm auth file from source Co-authored-by: Olivier Le Thanh Duong --- pyproject.toml | 1 + .../sdk/client/{vmclient.py => vm_client.py} | 35 ++- .../sdk/client/vm_confidential_client.py | 216 ++++++++++++++++++ src/aleph/sdk/types.py | 25 ++ src/aleph/sdk/utils.py | 164 ++++++++++++- tests/unit/aleph_vm_authentication.py | 46 ++-- .../{test_vmclient.py => test_vm_client.py} | 25 +- tests/unit/test_vm_confidential_client.py | 216 ++++++++++++++++++ 8 files changed, 673 insertions(+), 55 deletions(-) rename src/aleph/sdk/client/{vmclient.py => vm_client.py} (83%) create mode 100644 src/aleph/sdk/client/vm_confidential_client.py rename tests/unit/{test_vmclient.py => test_vm_client.py} (93%) create mode 100644 tests/unit/test_vm_confidential_client.py diff --git a/pyproject.toml b/pyproject.toml index b52efe66..1070a7f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "python-magic", "typer", "typing_extensions", + "aioresponses>=0.7.6" ] [project.optional-dependencies] diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vm_client.py similarity index 83% rename from src/aleph/sdk/client/vmclient.py rename to src/aleph/sdk/client/vm_client.py index 52e9d152..4092851d 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vm_client.py @@ -44,7 +44,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", - "domain": urlparse(self.node_url).netloc, + "domain": self.node_domain, "address": self.account.get_address(), "expires": ( datetime.datetime.utcnow() + datetime.timedelta(days=1) @@ -65,14 +65,16 @@ async def _generate_pubkey_signature_header(self) -> str: "sender": self.account.get_address(), "payload": pubkey_payload, "signature": pubkey_signature, - "content": {"domain": urlparse(self.node_url).netloc}, + "content": {"domain": self.node_domain}, } ) async def _generate_header( - self, vm_id: ItemHash, operation: str + self, vm_id: ItemHash, operation: str, method: str ) -> Tuple[str, Dict[str, str]]: - payload = create_vm_control_payload(vm_id, operation) + payload = create_vm_control_payload( + vm_id, operation, domain=self.node_domain, method=method + ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) if not self.pubkey_signature_header: @@ -88,18 +90,29 @@ async def _generate_header( path = payload["path"] return f"{self.node_url}{path}", headers + @property + def node_domain(self) -> str: + domain = urlparse(self.node_url).hostname + if not domain: + raise Exception("Could not parse node domain") + return domain + async def perform_operation( - self, vm_id: ItemHash, operation: str + self, vm_id: ItemHash, operation: str, method: str = "POST" ) -> Tuple[Optional[int], str]: if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() ) - url, header = await self._generate_header(vm_id=vm_id, operation=operation) + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method=method + ) try: - async with self.session.post(url, headers=header) as response: + async with self.session.request( + method=method, url=url, headers=header + ) as response: response_text = await response.text() return response.status, response_text @@ -113,7 +126,9 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: await self._generate_pubkey_signature_header() ) - payload = create_vm_control_payload(vm_id, "stream_logs") + payload = create_vm_control_payload( + vm_id, "stream_logs", method="get", domain=self.node_domain + ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) path = payload["path"] ws_url = f"{self.node_url}{path}" @@ -121,8 +136,8 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: async with self.session.ws_connect(ws_url) as ws: auth_message = { "auth": { - "X-SignedPubKey": self.pubkey_signature_header, - "X-SignedOperation": signed_operation, + "X-SignedPubKey": json.loads(self.pubkey_signature_header), + "X-SignedOperation": json.loads(signed_operation), } } await ws.send_json(auth_message) diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py new file mode 100644 index 00000000..a100de8c --- /dev/null +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -0,0 +1,216 @@ +import base64 +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.client.vm_client import VmClient +from aleph.sdk.types import Account, SEVMeasurement +from aleph.sdk.utils import ( + compute_confidential_measure, + encrypt_secret_table, + get_vm_measure, + make_packet_header, + make_secret_table, + run_in_subprocess, +) + +logger = logging.getLogger(__name__) + + +class VmConfidentialClient(VmClient): + sevctl_path: Path + + def __init__( + self, + account: Account, + sevctl_path: Path, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + super().__init__(account, node_url, session) + self.sevctl_path = sevctl_path + + async def get_certificates(self) -> Tuple[Optional[int], str]: + """ + Get platform confidential certificate + """ + + url = f"{self.node_url}/about/certificates" + try: + async with self.session.get(url) as response: + data = await response.read() + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file.write(data) + return response.status, tmp_file.name + + except aiohttp.ClientError as e: + logger.error( + f"HTTP error getting node certificates on {self.node_url}: {str(e)}" + ) + return None, str(e) + + async def create_session( + self, vm_id: ItemHash, certificate_path: Path, policy: int + ) -> Path: + """ + Create new confidential session + """ + + current_path = Path().cwd() + args = [ + "session", + "--name", + str(vm_id), + str(certificate_path), + str(policy), + ] + try: + # TODO: Check command result + await self.sevctl_cmd(*args) + return current_path + except Exception as e: + raise ValueError(f"Session creation have failed, reason: {str(e)}") + + async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str: + """ + Initialize Confidential VM negociation passing the needed session files + """ + + session_file = session.read_bytes() + godh_file = godh.read_bytes() + params = { + "session": session_file, + "godh": godh_file, + } + return await self.perform_confidential_operation( + vm_id, "confidential/initialize", params=params + ) + + async def measurement(self, vm_id: ItemHash) -> SEVMeasurement: + """ + Fetch VM confidential measurement + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + status, text = await self.perform_operation( + vm_id, "confidential/measurement", method="GET" + ) + sev_mesurement = SEVMeasurement.parse_raw(text) + return sev_mesurement + + async def validate_measure( + self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str + ) -> bool: + """ + Validate VM confidential measurement + """ + + tik = tik_path.read_bytes() + vm_measure, nonce = get_vm_measure(sev_data) + + expected_measure = compute_confidential_measure( + sev_info=sev_data.sev_info, + tik=tik, + expected_hash=firmware_hash, + nonce=nonce, + ).digest() + return expected_measure == vm_measure + + async def build_secret( + self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str + ) -> Tuple[str, str]: + """ + Build disk secret to be injected on the confidential VM + """ + + tek = tek_path.read_bytes() + tik = tik_path.read_bytes() + + vm_measure, _ = get_vm_measure(sev_data) + + iv = os.urandom(16) + secret_table = make_secret_table(secret) + encrypted_secret_table = encrypt_secret_table( + secret_table=secret_table, tek=tek, iv=iv + ) + + packet_header = make_packet_header( + vm_measure=vm_measure, + encrypted_secret_table=encrypted_secret_table, + secret_table_size=len(secret_table), + tik=tik, + iv=iv, + ) + + encoded_packet_header = base64.b64encode(packet_header).decode() + encoded_secret = base64.b64encode(encrypted_secret_table).decode() + + return encoded_packet_header, encoded_secret + + async def inject_secret( + self, vm_id: ItemHash, packet_header: str, secret: str + ) -> Dict: + """ + Send the secret by the encrypted channel to boot up the VM + """ + + params = { + "packet_header": packet_header, + "secret": secret, + } + text = await self.perform_confidential_operation( + vm_id, "confidential/inject_secret", json=params + ) + + return json.loads(text) + + async def perform_confidential_operation( + self, + vm_id: ItemHash, + operation: str, + params: Optional[Dict[str, Any]] = None, + json=None, + ) -> str: + """ + Send confidential operations to the CRN passing the auth headers on each request + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method="post" + ) + + try: + async with self.session.post( + url, headers=header, data=params, json=json + ) as response: + response.raise_for_status() + response_text = await response.text() + return response_text + + except aiohttp.ClientError as e: + raise ValueError(f"HTTP error during operation {operation}: {str(e)}") + + async def sevctl_cmd(self, *args) -> bytes: + """ + Execute `sevctl` command with given arguments + """ + + return await run_in_subprocess( + [str(self.sevctl_path), *args], + check=True, + ) diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 71bc2b53..cf9e6fa8 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -2,6 +2,8 @@ from enum import Enum from typing import Dict, Protocol, TypeVar +from pydantic import BaseModel + __all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") from aleph_message.models import AlephMessage @@ -39,3 +41,26 @@ async def sign_raw(self, buffer: bytes) -> bytes: ... GenericMessage = TypeVar("GenericMessage", bound=AlephMessage) + + +class SEVInfo(BaseModel): + """ + An AMD SEV platform information. + """ + + enabled: bool + api_major: int + api_minor: int + build_id: int + policy: int + state: str + handle: int + + +class SEVMeasurement(BaseModel): + """ + A SEV measurement data get from Qemu measurement. + """ + + sev_info: SEVInfo + launch_measure: str diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 2d1b30c7..5c641d5c 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,8 +1,12 @@ +import asyncio +import base64 import errno import hashlib +import hmac import json import logging import os +import subprocess from datetime import date, datetime, time from enum import Enum from pathlib import Path @@ -11,6 +15,7 @@ Any, Dict, Iterable, + List, Mapping, Optional, Protocol, @@ -20,16 +25,19 @@ Union, get_args, ) +from uuid import UUID from zipfile import BadZipFile, ZipFile from aleph_message.models import ItemHash, MessageType from aleph_message.models.execution.program import Encoding from aleph_message.models.execution.volume import MachineVolume +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from jwcrypto.jwa import JWA from pydantic.json import pydantic_encoder from aleph.sdk.conf import settings -from aleph.sdk.types import GenericMessage +from aleph.sdk.types import GenericMessage, SEVInfo, SEVMeasurement logger = logging.getLogger(__name__) @@ -200,12 +208,15 @@ def bytes_from_hex(hex_string: str) -> bytes: return hex_string -def create_vm_control_payload(vm_id: ItemHash, operation: str) -> Dict[str, str]: +def create_vm_control_payload( + vm_id: ItemHash, operation: str, domain: str, method: str +) -> Dict[str, str]: path = f"/control/machine/{vm_id}/{operation}" payload = { "time": datetime.utcnow().isoformat() + "Z", - "method": "POST", + "method": method.upper(), "path": path, + "domain": domain, } return payload @@ -220,3 +231,150 @@ def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: } ) return signed_operation + + +async def run_in_subprocess( + command: List[str], check: bool = True, stdin_input: Optional[bytes] = None +) -> bytes: + """Run the specified command in a subprocess, returns the stdout of the process.""" + logger.debug(f"command: {' '.join(command)}") + + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate(input=stdin_input) + + if check and process.returncode: + logger.error( + f"Command failed with error code {process.returncode}:\n" + f" stdin = {stdin_input!r}\n" + f" command = {command}\n" + f" stdout = {stderr!r}" + ) + raise subprocess.CalledProcessError( + process.returncode, str(command), stderr.decode() + ) + + return stdout + + +def get_vm_measure(sev_data: SEVMeasurement) -> Tuple[bytes, bytes]: + launch_measure = base64.b64decode(sev_data.launch_measure) + vm_measure = launch_measure[0:32] + nonce = launch_measure[32:48] + return vm_measure, nonce + + +def compute_confidential_measure( + sev_info: SEVInfo, tik: bytes, expected_hash: str, nonce: bytes +) -> hmac.HMAC: + """ + Computes the SEV measurement using the CRN SEV data and local variables like the OVMF firmware hash, + and the session key generated. + """ + + h = hmac.new(tik, digestmod="sha256") + + ## + # calculated per section 6.5.2 + ## + h.update(bytes([0x04])) + h.update(sev_info.api_major.to_bytes(1, byteorder="little")) + h.update(sev_info.api_minor.to_bytes(1, byteorder="little")) + h.update(sev_info.build_id.to_bytes(1, byteorder="little")) + h.update(sev_info.policy.to_bytes(4, byteorder="little")) + + expected_hash_bytes = bytearray.fromhex(expected_hash) + h.update(expected_hash_bytes) + + h.update(nonce) + + return h + + +def make_secret_table(secret: str) -> bytearray: + """ + Makes the disk secret table to be sent to the Confidential CRN + """ + + ## + # Construct the secret table: two guids + 4 byte lengths plus string + # and zero terminator + # + # Secret layout is guid, len (4 bytes), data + # with len being the length from start of guid to end of data + # + # The table header covers the entire table then each entry covers + # only its local data + # + # our current table has the header guid with total table length + # followed by the secret guid with the zero terminated secret + ## + + # total length of table: header plus one entry with trailing \0 + length = 16 + 4 + 16 + 4 + len(secret) + 1 + # SEV-ES requires rounding to 16 + length = (length + 15) & ~15 + secret_table = bytearray(length) + + secret_table[0:16] = UUID("{1e74f542-71dd-4d66-963e-ef4287ff173b}").bytes_le + secret_table[16:20] = len(secret_table).to_bytes(4, byteorder="little") + secret_table[20:36] = UUID("{736869e5-84f0-4973-92ec-06879ce3da0b}").bytes_le + secret_table[36:40] = (16 + 4 + len(secret) + 1).to_bytes(4, byteorder="little") + secret_table[40 : 40 + len(secret)] = secret.encode() + + return secret_table + + +def encrypt_secret_table(secret_table: bytes, tek: bytes, iv: bytes) -> bytes: + """Encrypt the secret table with the TEK in CTR mode using a random IV""" + + # Initialize the cipher with AES algorithm and CTR mode + cipher = Cipher(algorithms.AES(tek), modes.CTR(iv), backend=default_backend()) + encryptor = cipher.encryptor() + + # Encrypt the secret table + encrypted_secret = encryptor.update(secret_table) + encryptor.finalize() + + return encrypted_secret + + +def make_packet_header( + vm_measure: bytes, + encrypted_secret_table: bytes, + secret_table_size: int, + tik: bytes, + iv: bytes, +) -> bytearray: + """ + Creates a packet header using the encrypted disk secret table to be sent to the Confidential CRN + """ + + ## + # ultimately needs to be an argument, but there's only + # compressed and no real use case + ## + flags = 0 + + ## + # Table 55. LAUNCH_SECRET Packet Header Buffer + ## + header = bytearray(52) + header[0:4] = flags.to_bytes(4, byteorder="little") + header[4:20] = iv + + h = hmac.new(tik, digestmod="sha256") + h.update(bytes([0x01])) + # FLAGS || IV + h.update(header[0:20]) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(encrypted_secret_table) + h.update(vm_measure) + + header[20:52] = h.digest() + + return header diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 7d213547..491da51a 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Awaitable, Coroutine -from typing import Any, Callable, Dict, Literal, Union +from typing import Any, Callable, Dict, Literal, Optional, Union import cryptography.exceptions import pydantic @@ -49,7 +49,6 @@ class SignedPubKeyPayload(BaseModel): # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} # alg: Literal["ECDSA"] - domain: str address: str expires: str @@ -77,7 +76,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes: return bytes_from_hex(value.decode()) @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: + def check_expiry(cls, values) -> Dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) @@ -104,18 +103,18 @@ def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: @property def content(self) -> SignedPubKeyPayload: """Return the content of the header""" - return SignedPubKeyPayload.parse_raw(self.payload) class SignedOperationPayload(BaseModel): time: datetime.datetime method: Union[Literal["POST"], Literal["GET"]] + domain: str path: str # body_sha256: str # disabled since there is no body @validator("time") - def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: + def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: """Check that the time is current and the payload is not a replay attack.""" max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( minutes=2 @@ -123,14 +122,11 @@ def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: max_future = datetime.datetime.now( tz=datetime.timezone.utc ) + datetime.timedelta(minutes=2) - - if value < max_past: + if v < max_past: raise ValueError("Time is too far in the past") - - if value > max_future: + if v > max_future: raise ValueError("Time is too far in the future") - - return value + return v class SignedOperation(BaseModel): @@ -152,12 +148,10 @@ def signature_must_be_hex(cls, value: str) -> bytes: raise error @validator("payload") - def payload_must_be_hex(cls, value: bytes) -> bytes: + def payload_must_be_hex(cls, v) -> bytes: """Convert the payload from hexadecimal to bytes""" - - v = bytes_from_hex(value.decode()) + v = bytes.fromhex(v.decode()) _ = SignedOperationPayload.parse_raw(v) - return v @property @@ -197,7 +191,6 @@ def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: if str(err.exc) == "Invalid signature": raise web.HTTPUnauthorized(reason="Invalid signature") from errors - else: raise errors @@ -207,13 +200,10 @@ def get_signed_operation(request: web.Request) -> SignedOperation: try: signed_operation = request.headers["X-SignedOperation"] return SignedOperation.parse_raw(signed_operation) - except KeyError as error: raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error - except json.JSONDecodeError as error: raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error - except ValidationError as error: logger.debug(f"Invalid X-SignedOperation fields: {error}") raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error @@ -239,14 +229,16 @@ def verify_signed_operation( raise web.HTTPUnauthorized(reason="Signature could not verified") -async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) -> str: +async def authenticate_jwk( + request: web.Request, domain_name: Optional[str] = DOMAIN_NAME +) -> str: """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) - if signed_pubkey.content.domain != domain_name: + if signed_operation.content.domain != domain_name: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") @@ -255,36 +247,31 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) f"Invalid path '{signed_operation.content.path}' != '{request.path}'" ) raise web.HTTPUnauthorized(reason="Invalid path") - if signed_operation.content.method != request.method: logger.debug( f"Invalid method '{signed_operation.content.method}' != '{request.method}'" ) raise web.HTTPUnauthorized(reason="Invalid method") - return verify_signed_operation(signed_operation, signed_pubkey) async def authenticate_websocket_message( - message, domain_name: str = DOMAIN_NAME + message, domain_name: Optional[str] = DOMAIN_NAME ) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) - - if signed_pubkey.content.domain != domain_name: + if signed_operation.content.domain != domain_name: logger.debug( f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") - return verify_signed_operation(signed_operation, signed_pubkey) def require_jwk_authentication( handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: - @functools.wraps(handler) async def wrapper(request): try: @@ -296,6 +283,7 @@ async def wrapper(request): logging.exception(e) raise + # authenticated_sender is the authenticted wallet address of the requester (as a string) response = await handler(request, authenticated_sender) return response diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vm_client.py similarity index 93% rename from tests/unit/test_vmclient.py rename to tests/unit/test_vm_client.py index bc201472..7cc9a2c3 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vm_client.py @@ -1,4 +1,3 @@ -import json from urllib.parse import urlparse import aiohttp @@ -9,7 +8,7 @@ from yarl import URL from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client.vmclient import VmClient +from aleph.sdk.client.vm_client import VmClient from .aleph_vm_authentication import ( SignedOperation, @@ -202,7 +201,9 @@ async def test_authenticate_jwk(aiohttp_client): vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") async def test_authenticate_route(request): - address = await authenticate_jwk(request, domain_name=urlparse(node_url).netloc) + address = await authenticate_jwk( + request, domain_name=urlparse(node_url).hostname + ) assert vm_client.account.get_address() == address return web.Response(text="ok") @@ -222,7 +223,7 @@ async def test_authenticate_route(request): ) status_code, response_text = await vm_client.stop_instance(vm_id) - assert status_code == 200 + assert status_code == 200, response_text assert response_text == "ok" await vm_client.session.close() @@ -239,16 +240,13 @@ async def websocket_handler(request): first_message = await ws.receive_json() credentials = first_message["auth"] - address = await authenticate_websocket_message( - { - "X-SignedPubKey": json.loads(credentials["X-SignedPubKey"]), - "X-SignedOperation": json.loads(credentials["X-SignedOperation"]), - }, - domain_name=urlparse(node_url).netloc, + sender_address = await authenticate_websocket_message( + credentials, + domain_name=urlparse(node_url).hostname, ) - assert vm_client.account.get_address() == address - await ws.send_str(address) + assert vm_client.account.get_address() == sender_address + await ws.send_str(sender_address) return ws @@ -268,6 +266,7 @@ async def websocket_handler(request): ) valid = False + async for address in vm_client.get_logs(vm_id): assert address == vm_client.account.get_address() valid = True @@ -290,7 +289,7 @@ async def test_vm_client_generate_correct_authentication_headers(): session=aiohttp.ClientSession(), ) - path, headers = await vm_client._generate_header(vm_id, "reboot") + path, headers = await vm_client._generate_header(vm_id, "reboot", method="post") signed_pubkey = SignedPubKeyHeader.parse_raw(headers["X-SignedPubKey"]) signed_operation = SignedOperation.parse_raw(headers["X-SignedOperation"]) address = verify_signed_operation(signed_operation, signed_pubkey) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py new file mode 100644 index 00000000..832871ff --- /dev/null +++ b/tests/unit/test_vm_confidential_client.py @@ -0,0 +1,216 @@ +import tempfile +from pathlib import Path +from unittest import mock +from unittest.mock import patch + +import aiohttp +import pytest +from aioresponses import aioresponses +from aleph_message.models import ItemHash + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_confidential_client import VmConfidentialClient + + +@pytest.mark.asyncio +async def test_perform_confidential_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/test" + + with aioresponses() as m: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + response_text = await vm_client.perform_confidential_operation(vm_id, operation) + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_initialize_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/initialize" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_bytes = Path(tmp_file.name).read_bytes() + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + tmp_file_path = Path(tmp_file.name) + response_text = await vm_client.initialize( + vm_id, session=tmp_file_path, godh=tmp_file_path + ) + assert ( + response_text == '"mock_response_text"' + ) # ' ' cause by aioresponses + m.assert_called_once_with( + url, + method="POST", + data={ + "session": tmp_file_bytes, + "godh": tmp_file_bytes, + }, + json=None, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_measurement_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/measurement" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.get( + url, + status=200, + payload=dict( + { + "sev_info": { + "enabled": True, + "api_major": 0, + "api_minor": 0, + "build_id": 0, + "policy": 0, + "state": "", + "handle": 0, + }, + "launch_measure": "test_measure", + } + ), + ) + measurement = await vm_client.measurement(vm_id) + assert measurement.launch_measure == "test_measure" + m.assert_called_once_with( + url, + method="GET", + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_inject_secret_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/inject_secret" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + test_secret = "test_secret" + packet_header = "test_packet_header" + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + response_text = await vm_client.inject_secret( + vm_id, secret=test_secret, packet_header=packet_header + ) + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + json={ + "secret": test_secret, + "packet_header": packet_header, + }, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_create_session_command(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + node_url = "http://localhost" + sevctl_path = Path("/usr/bin/sevctl") + certificates_path = Path("/") + policy = 1 + + with mock.patch( + "aleph.sdk.client.vm_confidential_client.run_in_subprocess", + return_value=True, + ) as export_mock: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=sevctl_path, + node_url=node_url, + session=aiohttp.ClientSession(), + ) + _ = await vm_client.create_session(vm_id, certificates_path, policy) + export_mock.assert_called_once_with( + [ + str(sevctl_path), + "session", + "--name", + str(vm_id), + str(certificates_path), + str(policy), + ], + check=True, + )