From 247dbfcb01c0d2ab8bbe9c4b11854fd94044fff6 Mon Sep 17 00:00:00 2001 From: nesitor Date: Thu, 4 Jul 2024 22:32:19 +0200 Subject: [PATCH] Implement `VmConfidentialClient` class (#138) * Problem: A user cannot initialize an already created confidential VM. Solution: Implement `VmConfidentialClient` class to be able to initialize and interact with confidential VMs. * Problem: Auth was not working Corrections: * Measurement type returned was missing field needed for validation of measurements * Port number was not handled correctly in authentifaction * Adapt to new auth protocol where domain is moved to the operation field (While keeping compat with the old format) * Get measurement was not working since signed with the wrong method * inject_secret was not sending a json * Websocked auth was sending a twice serialized json * update 'vendorized' aleph-vm auth file from source Co-authored-by: Olivier Le Thanh Duong --- pyproject.toml | 1 + .../sdk/client/{vmclient.py => vm_client.py} | 35 ++- .../sdk/client/vm_confidential_client.py | 216 ++++++++++++++++++ src/aleph/sdk/types.py | 25 ++ src/aleph/sdk/utils.py | 164 ++++++++++++- tests/unit/aleph_vm_authentication.py | 46 ++-- .../{test_vmclient.py => test_vm_client.py} | 25 +- tests/unit/test_vm_confidential_client.py | 216 ++++++++++++++++++ 8 files changed, 673 insertions(+), 55 deletions(-) rename src/aleph/sdk/client/{vmclient.py => vm_client.py} (83%) create mode 100644 src/aleph/sdk/client/vm_confidential_client.py rename tests/unit/{test_vmclient.py => test_vm_client.py} (93%) create mode 100644 tests/unit/test_vm_confidential_client.py diff --git a/pyproject.toml b/pyproject.toml index b52efe66..1070a7f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "python-magic", "typer", "typing_extensions", + "aioresponses>=0.7.6" ] [project.optional-dependencies] diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vm_client.py similarity index 83% rename from src/aleph/sdk/client/vmclient.py rename to src/aleph/sdk/client/vm_client.py index 52e9d152..4092851d 100644 --- a/src/aleph/sdk/client/vmclient.py +++ b/src/aleph/sdk/client/vm_client.py @@ -44,7 +44,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", - "domain": urlparse(self.node_url).netloc, + "domain": self.node_domain, "address": self.account.get_address(), "expires": ( datetime.datetime.utcnow() + datetime.timedelta(days=1) @@ -65,14 +65,16 @@ async def _generate_pubkey_signature_header(self) -> str: "sender": self.account.get_address(), "payload": pubkey_payload, "signature": pubkey_signature, - "content": {"domain": urlparse(self.node_url).netloc}, + "content": {"domain": self.node_domain}, } ) async def _generate_header( - self, vm_id: ItemHash, operation: str + self, vm_id: ItemHash, operation: str, method: str ) -> Tuple[str, Dict[str, str]]: - payload = create_vm_control_payload(vm_id, operation) + payload = create_vm_control_payload( + vm_id, operation, domain=self.node_domain, method=method + ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) if not self.pubkey_signature_header: @@ -88,18 +90,29 @@ async def _generate_header( path = payload["path"] return f"{self.node_url}{path}", headers + @property + def node_domain(self) -> str: + domain = urlparse(self.node_url).hostname + if not domain: + raise Exception("Could not parse node domain") + return domain + async def perform_operation( - self, vm_id: ItemHash, operation: str + self, vm_id: ItemHash, operation: str, method: str = "POST" ) -> Tuple[Optional[int], str]: if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() ) - url, header = await self._generate_header(vm_id=vm_id, operation=operation) + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method=method + ) try: - async with self.session.post(url, headers=header) as response: + async with self.session.request( + method=method, url=url, headers=header + ) as response: response_text = await response.text() return response.status, response_text @@ -113,7 +126,9 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: await self._generate_pubkey_signature_header() ) - payload = create_vm_control_payload(vm_id, "stream_logs") + payload = create_vm_control_payload( + vm_id, "stream_logs", method="get", domain=self.node_domain + ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) path = payload["path"] ws_url = f"{self.node_url}{path}" @@ -121,8 +136,8 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: async with self.session.ws_connect(ws_url) as ws: auth_message = { "auth": { - "X-SignedPubKey": self.pubkey_signature_header, - "X-SignedOperation": signed_operation, + "X-SignedPubKey": json.loads(self.pubkey_signature_header), + "X-SignedOperation": json.loads(signed_operation), } } await ws.send_json(auth_message) diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py new file mode 100644 index 00000000..a100de8c --- /dev/null +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -0,0 +1,216 @@ +import base64 +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.client.vm_client import VmClient +from aleph.sdk.types import Account, SEVMeasurement +from aleph.sdk.utils import ( + compute_confidential_measure, + encrypt_secret_table, + get_vm_measure, + make_packet_header, + make_secret_table, + run_in_subprocess, +) + +logger = logging.getLogger(__name__) + + +class VmConfidentialClient(VmClient): + sevctl_path: Path + + def __init__( + self, + account: Account, + sevctl_path: Path, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + super().__init__(account, node_url, session) + self.sevctl_path = sevctl_path + + async def get_certificates(self) -> Tuple[Optional[int], str]: + """ + Get platform confidential certificate + """ + + url = f"{self.node_url}/about/certificates" + try: + async with self.session.get(url) as response: + data = await response.read() + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file.write(data) + return response.status, tmp_file.name + + except aiohttp.ClientError as e: + logger.error( + f"HTTP error getting node certificates on {self.node_url}: {str(e)}" + ) + return None, str(e) + + async def create_session( + self, vm_id: ItemHash, certificate_path: Path, policy: int + ) -> Path: + """ + Create new confidential session + """ + + current_path = Path().cwd() + args = [ + "session", + "--name", + str(vm_id), + str(certificate_path), + str(policy), + ] + try: + # TODO: Check command result + await self.sevctl_cmd(*args) + return current_path + except Exception as e: + raise ValueError(f"Session creation have failed, reason: {str(e)}") + + async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str: + """ + Initialize Confidential VM negociation passing the needed session files + """ + + session_file = session.read_bytes() + godh_file = godh.read_bytes() + params = { + "session": session_file, + "godh": godh_file, + } + return await self.perform_confidential_operation( + vm_id, "confidential/initialize", params=params + ) + + async def measurement(self, vm_id: ItemHash) -> SEVMeasurement: + """ + Fetch VM confidential measurement + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + status, text = await self.perform_operation( + vm_id, "confidential/measurement", method="GET" + ) + sev_mesurement = SEVMeasurement.parse_raw(text) + return sev_mesurement + + async def validate_measure( + self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str + ) -> bool: + """ + Validate VM confidential measurement + """ + + tik = tik_path.read_bytes() + vm_measure, nonce = get_vm_measure(sev_data) + + expected_measure = compute_confidential_measure( + sev_info=sev_data.sev_info, + tik=tik, + expected_hash=firmware_hash, + nonce=nonce, + ).digest() + return expected_measure == vm_measure + + async def build_secret( + self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str + ) -> Tuple[str, str]: + """ + Build disk secret to be injected on the confidential VM + """ + + tek = tek_path.read_bytes() + tik = tik_path.read_bytes() + + vm_measure, _ = get_vm_measure(sev_data) + + iv = os.urandom(16) + secret_table = make_secret_table(secret) + encrypted_secret_table = encrypt_secret_table( + secret_table=secret_table, tek=tek, iv=iv + ) + + packet_header = make_packet_header( + vm_measure=vm_measure, + encrypted_secret_table=encrypted_secret_table, + secret_table_size=len(secret_table), + tik=tik, + iv=iv, + ) + + encoded_packet_header = base64.b64encode(packet_header).decode() + encoded_secret = base64.b64encode(encrypted_secret_table).decode() + + return encoded_packet_header, encoded_secret + + async def inject_secret( + self, vm_id: ItemHash, packet_header: str, secret: str + ) -> Dict: + """ + Send the secret by the encrypted channel to boot up the VM + """ + + params = { + "packet_header": packet_header, + "secret": secret, + } + text = await self.perform_confidential_operation( + vm_id, "confidential/inject_secret", json=params + ) + + return json.loads(text) + + async def perform_confidential_operation( + self, + vm_id: ItemHash, + operation: str, + params: Optional[Dict[str, Any]] = None, + json=None, + ) -> str: + """ + Send confidential operations to the CRN passing the auth headers on each request + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method="post" + ) + + try: + async with self.session.post( + url, headers=header, data=params, json=json + ) as response: + response.raise_for_status() + response_text = await response.text() + return response_text + + except aiohttp.ClientError as e: + raise ValueError(f"HTTP error during operation {operation}: {str(e)}") + + async def sevctl_cmd(self, *args) -> bytes: + """ + Execute `sevctl` command with given arguments + """ + + return await run_in_subprocess( + [str(self.sevctl_path), *args], + check=True, + ) diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 71bc2b53..cf9e6fa8 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -2,6 +2,8 @@ from enum import Enum from typing import Dict, Protocol, TypeVar +from pydantic import BaseModel + __all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") from aleph_message.models import AlephMessage @@ -39,3 +41,26 @@ async def sign_raw(self, buffer: bytes) -> bytes: ... GenericMessage = TypeVar("GenericMessage", bound=AlephMessage) + + +class SEVInfo(BaseModel): + """ + An AMD SEV platform information. + """ + + enabled: bool + api_major: int + api_minor: int + build_id: int + policy: int + state: str + handle: int + + +class SEVMeasurement(BaseModel): + """ + A SEV measurement data get from Qemu measurement. + """ + + sev_info: SEVInfo + launch_measure: str diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 2d1b30c7..5c641d5c 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,8 +1,12 @@ +import asyncio +import base64 import errno import hashlib +import hmac import json import logging import os +import subprocess from datetime import date, datetime, time from enum import Enum from pathlib import Path @@ -11,6 +15,7 @@ Any, Dict, Iterable, + List, Mapping, Optional, Protocol, @@ -20,16 +25,19 @@ Union, get_args, ) +from uuid import UUID from zipfile import BadZipFile, ZipFile from aleph_message.models import ItemHash, MessageType from aleph_message.models.execution.program import Encoding from aleph_message.models.execution.volume import MachineVolume +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from jwcrypto.jwa import JWA from pydantic.json import pydantic_encoder from aleph.sdk.conf import settings -from aleph.sdk.types import GenericMessage +from aleph.sdk.types import GenericMessage, SEVInfo, SEVMeasurement logger = logging.getLogger(__name__) @@ -200,12 +208,15 @@ def bytes_from_hex(hex_string: str) -> bytes: return hex_string -def create_vm_control_payload(vm_id: ItemHash, operation: str) -> Dict[str, str]: +def create_vm_control_payload( + vm_id: ItemHash, operation: str, domain: str, method: str +) -> Dict[str, str]: path = f"/control/machine/{vm_id}/{operation}" payload = { "time": datetime.utcnow().isoformat() + "Z", - "method": "POST", + "method": method.upper(), "path": path, + "domain": domain, } return payload @@ -220,3 +231,150 @@ def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: } ) return signed_operation + + +async def run_in_subprocess( + command: List[str], check: bool = True, stdin_input: Optional[bytes] = None +) -> bytes: + """Run the specified command in a subprocess, returns the stdout of the process.""" + logger.debug(f"command: {' '.join(command)}") + + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate(input=stdin_input) + + if check and process.returncode: + logger.error( + f"Command failed with error code {process.returncode}:\n" + f" stdin = {stdin_input!r}\n" + f" command = {command}\n" + f" stdout = {stderr!r}" + ) + raise subprocess.CalledProcessError( + process.returncode, str(command), stderr.decode() + ) + + return stdout + + +def get_vm_measure(sev_data: SEVMeasurement) -> Tuple[bytes, bytes]: + launch_measure = base64.b64decode(sev_data.launch_measure) + vm_measure = launch_measure[0:32] + nonce = launch_measure[32:48] + return vm_measure, nonce + + +def compute_confidential_measure( + sev_info: SEVInfo, tik: bytes, expected_hash: str, nonce: bytes +) -> hmac.HMAC: + """ + Computes the SEV measurement using the CRN SEV data and local variables like the OVMF firmware hash, + and the session key generated. + """ + + h = hmac.new(tik, digestmod="sha256") + + ## + # calculated per section 6.5.2 + ## + h.update(bytes([0x04])) + h.update(sev_info.api_major.to_bytes(1, byteorder="little")) + h.update(sev_info.api_minor.to_bytes(1, byteorder="little")) + h.update(sev_info.build_id.to_bytes(1, byteorder="little")) + h.update(sev_info.policy.to_bytes(4, byteorder="little")) + + expected_hash_bytes = bytearray.fromhex(expected_hash) + h.update(expected_hash_bytes) + + h.update(nonce) + + return h + + +def make_secret_table(secret: str) -> bytearray: + """ + Makes the disk secret table to be sent to the Confidential CRN + """ + + ## + # Construct the secret table: two guids + 4 byte lengths plus string + # and zero terminator + # + # Secret layout is guid, len (4 bytes), data + # with len being the length from start of guid to end of data + # + # The table header covers the entire table then each entry covers + # only its local data + # + # our current table has the header guid with total table length + # followed by the secret guid with the zero terminated secret + ## + + # total length of table: header plus one entry with trailing \0 + length = 16 + 4 + 16 + 4 + len(secret) + 1 + # SEV-ES requires rounding to 16 + length = (length + 15) & ~15 + secret_table = bytearray(length) + + secret_table[0:16] = UUID("{1e74f542-71dd-4d66-963e-ef4287ff173b}").bytes_le + secret_table[16:20] = len(secret_table).to_bytes(4, byteorder="little") + secret_table[20:36] = UUID("{736869e5-84f0-4973-92ec-06879ce3da0b}").bytes_le + secret_table[36:40] = (16 + 4 + len(secret) + 1).to_bytes(4, byteorder="little") + secret_table[40 : 40 + len(secret)] = secret.encode() + + return secret_table + + +def encrypt_secret_table(secret_table: bytes, tek: bytes, iv: bytes) -> bytes: + """Encrypt the secret table with the TEK in CTR mode using a random IV""" + + # Initialize the cipher with AES algorithm and CTR mode + cipher = Cipher(algorithms.AES(tek), modes.CTR(iv), backend=default_backend()) + encryptor = cipher.encryptor() + + # Encrypt the secret table + encrypted_secret = encryptor.update(secret_table) + encryptor.finalize() + + return encrypted_secret + + +def make_packet_header( + vm_measure: bytes, + encrypted_secret_table: bytes, + secret_table_size: int, + tik: bytes, + iv: bytes, +) -> bytearray: + """ + Creates a packet header using the encrypted disk secret table to be sent to the Confidential CRN + """ + + ## + # ultimately needs to be an argument, but there's only + # compressed and no real use case + ## + flags = 0 + + ## + # Table 55. LAUNCH_SECRET Packet Header Buffer + ## + header = bytearray(52) + header[0:4] = flags.to_bytes(4, byteorder="little") + header[4:20] = iv + + h = hmac.new(tik, digestmod="sha256") + h.update(bytes([0x01])) + # FLAGS || IV + h.update(header[0:20]) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(encrypted_secret_table) + h.update(vm_measure) + + header[20:52] = h.digest() + + return header diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 7d213547..491da51a 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Awaitable, Coroutine -from typing import Any, Callable, Dict, Literal, Union +from typing import Any, Callable, Dict, Literal, Optional, Union import cryptography.exceptions import pydantic @@ -49,7 +49,6 @@ class SignedPubKeyPayload(BaseModel): # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} # alg: Literal["ECDSA"] - domain: str address: str expires: str @@ -77,7 +76,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes: return bytes_from_hex(value.decode()) @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: + def check_expiry(cls, values) -> Dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) @@ -104,18 +103,18 @@ def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: @property def content(self) -> SignedPubKeyPayload: """Return the content of the header""" - return SignedPubKeyPayload.parse_raw(self.payload) class SignedOperationPayload(BaseModel): time: datetime.datetime method: Union[Literal["POST"], Literal["GET"]] + domain: str path: str # body_sha256: str # disabled since there is no body @validator("time") - def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: + def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: """Check that the time is current and the payload is not a replay attack.""" max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( minutes=2 @@ -123,14 +122,11 @@ def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: max_future = datetime.datetime.now( tz=datetime.timezone.utc ) + datetime.timedelta(minutes=2) - - if value < max_past: + if v < max_past: raise ValueError("Time is too far in the past") - - if value > max_future: + if v > max_future: raise ValueError("Time is too far in the future") - - return value + return v class SignedOperation(BaseModel): @@ -152,12 +148,10 @@ def signature_must_be_hex(cls, value: str) -> bytes: raise error @validator("payload") - def payload_must_be_hex(cls, value: bytes) -> bytes: + def payload_must_be_hex(cls, v) -> bytes: """Convert the payload from hexadecimal to bytes""" - - v = bytes_from_hex(value.decode()) + v = bytes.fromhex(v.decode()) _ = SignedOperationPayload.parse_raw(v) - return v @property @@ -197,7 +191,6 @@ def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: if str(err.exc) == "Invalid signature": raise web.HTTPUnauthorized(reason="Invalid signature") from errors - else: raise errors @@ -207,13 +200,10 @@ def get_signed_operation(request: web.Request) -> SignedOperation: try: signed_operation = request.headers["X-SignedOperation"] return SignedOperation.parse_raw(signed_operation) - except KeyError as error: raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error - except json.JSONDecodeError as error: raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error - except ValidationError as error: logger.debug(f"Invalid X-SignedOperation fields: {error}") raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error @@ -239,14 +229,16 @@ def verify_signed_operation( raise web.HTTPUnauthorized(reason="Signature could not verified") -async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) -> str: +async def authenticate_jwk( + request: web.Request, domain_name: Optional[str] = DOMAIN_NAME +) -> str: """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) - if signed_pubkey.content.domain != domain_name: + if signed_operation.content.domain != domain_name: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") @@ -255,36 +247,31 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) f"Invalid path '{signed_operation.content.path}' != '{request.path}'" ) raise web.HTTPUnauthorized(reason="Invalid path") - if signed_operation.content.method != request.method: logger.debug( f"Invalid method '{signed_operation.content.method}' != '{request.method}'" ) raise web.HTTPUnauthorized(reason="Invalid method") - return verify_signed_operation(signed_operation, signed_pubkey) async def authenticate_websocket_message( - message, domain_name: str = DOMAIN_NAME + message, domain_name: Optional[str] = DOMAIN_NAME ) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) - - if signed_pubkey.content.domain != domain_name: + if signed_operation.content.domain != domain_name: logger.debug( f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") - return verify_signed_operation(signed_operation, signed_pubkey) def require_jwk_authentication( handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: - @functools.wraps(handler) async def wrapper(request): try: @@ -296,6 +283,7 @@ async def wrapper(request): logging.exception(e) raise + # authenticated_sender is the authenticted wallet address of the requester (as a string) response = await handler(request, authenticated_sender) return response diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vm_client.py similarity index 93% rename from tests/unit/test_vmclient.py rename to tests/unit/test_vm_client.py index bc201472..7cc9a2c3 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vm_client.py @@ -1,4 +1,3 @@ -import json from urllib.parse import urlparse import aiohttp @@ -9,7 +8,7 @@ from yarl import URL from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client.vmclient import VmClient +from aleph.sdk.client.vm_client import VmClient from .aleph_vm_authentication import ( SignedOperation, @@ -202,7 +201,9 @@ async def test_authenticate_jwk(aiohttp_client): vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") async def test_authenticate_route(request): - address = await authenticate_jwk(request, domain_name=urlparse(node_url).netloc) + address = await authenticate_jwk( + request, domain_name=urlparse(node_url).hostname + ) assert vm_client.account.get_address() == address return web.Response(text="ok") @@ -222,7 +223,7 @@ async def test_authenticate_route(request): ) status_code, response_text = await vm_client.stop_instance(vm_id) - assert status_code == 200 + assert status_code == 200, response_text assert response_text == "ok" await vm_client.session.close() @@ -239,16 +240,13 @@ async def websocket_handler(request): first_message = await ws.receive_json() credentials = first_message["auth"] - address = await authenticate_websocket_message( - { - "X-SignedPubKey": json.loads(credentials["X-SignedPubKey"]), - "X-SignedOperation": json.loads(credentials["X-SignedOperation"]), - }, - domain_name=urlparse(node_url).netloc, + sender_address = await authenticate_websocket_message( + credentials, + domain_name=urlparse(node_url).hostname, ) - assert vm_client.account.get_address() == address - await ws.send_str(address) + assert vm_client.account.get_address() == sender_address + await ws.send_str(sender_address) return ws @@ -268,6 +266,7 @@ async def websocket_handler(request): ) valid = False + async for address in vm_client.get_logs(vm_id): assert address == vm_client.account.get_address() valid = True @@ -290,7 +289,7 @@ async def test_vm_client_generate_correct_authentication_headers(): session=aiohttp.ClientSession(), ) - path, headers = await vm_client._generate_header(vm_id, "reboot") + path, headers = await vm_client._generate_header(vm_id, "reboot", method="post") signed_pubkey = SignedPubKeyHeader.parse_raw(headers["X-SignedPubKey"]) signed_operation = SignedOperation.parse_raw(headers["X-SignedOperation"]) address = verify_signed_operation(signed_operation, signed_pubkey) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py new file mode 100644 index 00000000..832871ff --- /dev/null +++ b/tests/unit/test_vm_confidential_client.py @@ -0,0 +1,216 @@ +import tempfile +from pathlib import Path +from unittest import mock +from unittest.mock import patch + +import aiohttp +import pytest +from aioresponses import aioresponses +from aleph_message.models import ItemHash + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_confidential_client import VmConfidentialClient + + +@pytest.mark.asyncio +async def test_perform_confidential_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/test" + + with aioresponses() as m: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + response_text = await vm_client.perform_confidential_operation(vm_id, operation) + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_initialize_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/initialize" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_bytes = Path(tmp_file.name).read_bytes() + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + tmp_file_path = Path(tmp_file.name) + response_text = await vm_client.initialize( + vm_id, session=tmp_file_path, godh=tmp_file_path + ) + assert ( + response_text == '"mock_response_text"' + ) # ' ' cause by aioresponses + m.assert_called_once_with( + url, + method="POST", + data={ + "session": tmp_file_bytes, + "godh": tmp_file_bytes, + }, + json=None, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_measurement_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/measurement" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.get( + url, + status=200, + payload=dict( + { + "sev_info": { + "enabled": True, + "api_major": 0, + "api_minor": 0, + "build_id": 0, + "policy": 0, + "state": "", + "handle": 0, + }, + "launch_measure": "test_measure", + } + ), + ) + measurement = await vm_client.measurement(vm_id) + assert measurement.launch_measure == "test_measure" + m.assert_called_once_with( + url, + method="GET", + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_inject_secret_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/inject_secret" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + test_secret = "test_secret" + packet_header = "test_packet_header" + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + response_text = await vm_client.inject_secret( + vm_id, secret=test_secret, packet_header=packet_header + ) + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + json={ + "secret": test_secret, + "packet_header": packet_header, + }, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_create_session_command(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + node_url = "http://localhost" + sevctl_path = Path("/usr/bin/sevctl") + certificates_path = Path("/") + policy = 1 + + with mock.patch( + "aleph.sdk.client.vm_confidential_client.run_in_subprocess", + return_value=True, + ) as export_mock: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=sevctl_path, + node_url=node_url, + session=aiohttp.ClientSession(), + ) + _ = await vm_client.create_session(vm_id, certificates_path, policy) + export_mock.assert_called_once_with( + [ + str(sevctl_path), + "session", + "--name", + str(vm_id), + str(certificates_path), + str(policy), + ], + check=True, + )