Skip to content

Commit

Permalink
Implement VmConfidentialClient class (#138)
Browse files Browse the repository at this point in the history
* Problem: A user cannot initialize an already created confidential VM.

Solution: Implement `VmConfidentialClient` class to be able to initialize and interact with confidential VMs.

* Problem: Auth was not working

Corrections:
 *  Measurement type returned was missing field needed for validation of measurements
 * Port number was not handled correctly in authentifaction
 * Adapt to new auth protocol where domain is moved to the operation field (While keeping compat with the old format)
 * Get measurement was not working since signed with the wrong method
 * inject_secret was not sending a json
 * Websocked auth was sending a twice serialized json
 * update 'vendorized' aleph-vm auth file from source

Co-authored-by: Olivier Le Thanh Duong <[email protected]>
  • Loading branch information
nesitor and olethanh authored Jul 4, 2024
1 parent d9b1892 commit 247dbfc
Show file tree
Hide file tree
Showing 8 changed files with 673 additions and 55 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"python-magic",
"typer",
"typing_extensions",
"aioresponses>=0.7.6"
]

[project.optional-dependencies]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]:
return {
"pubkey": json.loads(self.ephemeral_key.export_public()),
"alg": "ECDSA",
"domain": urlparse(self.node_url).netloc,
"domain": self.node_domain,
"address": self.account.get_address(),
"expires": (
datetime.datetime.utcnow() + datetime.timedelta(days=1)
Expand All @@ -65,14 +65,16 @@ async def _generate_pubkey_signature_header(self) -> str:
"sender": self.account.get_address(),
"payload": pubkey_payload,
"signature": pubkey_signature,
"content": {"domain": urlparse(self.node_url).netloc},
"content": {"domain": self.node_domain},
}
)

async def _generate_header(
self, vm_id: ItemHash, operation: str
self, vm_id: ItemHash, operation: str, method: str
) -> Tuple[str, Dict[str, str]]:
payload = create_vm_control_payload(vm_id, operation)
payload = create_vm_control_payload(
vm_id, operation, domain=self.node_domain, method=method
)
signed_operation = sign_vm_control_payload(payload, self.ephemeral_key)

if not self.pubkey_signature_header:
Expand All @@ -88,18 +90,29 @@ async def _generate_header(
path = payload["path"]
return f"{self.node_url}{path}", headers

@property
def node_domain(self) -> str:
domain = urlparse(self.node_url).hostname
if not domain:
raise Exception("Could not parse node domain")
return domain

async def perform_operation(
self, vm_id: ItemHash, operation: str
self, vm_id: ItemHash, operation: str, method: str = "POST"
) -> Tuple[Optional[int], str]:
if not self.pubkey_signature_header:
self.pubkey_signature_header = (
await self._generate_pubkey_signature_header()
)

url, header = await self._generate_header(vm_id=vm_id, operation=operation)
url, header = await self._generate_header(
vm_id=vm_id, operation=operation, method=method
)

try:
async with self.session.post(url, headers=header) as response:
async with self.session.request(
method=method, url=url, headers=header
) as response:
response_text = await response.text()
return response.status, response_text

Expand All @@ -113,16 +126,18 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]:
await self._generate_pubkey_signature_header()
)

payload = create_vm_control_payload(vm_id, "stream_logs")
payload = create_vm_control_payload(
vm_id, "stream_logs", method="get", domain=self.node_domain
)
signed_operation = sign_vm_control_payload(payload, self.ephemeral_key)
path = payload["path"]
ws_url = f"{self.node_url}{path}"

async with self.session.ws_connect(ws_url) as ws:
auth_message = {
"auth": {
"X-SignedPubKey": self.pubkey_signature_header,
"X-SignedOperation": signed_operation,
"X-SignedPubKey": json.loads(self.pubkey_signature_header),
"X-SignedOperation": json.loads(signed_operation),
}
}
await ws.send_json(auth_message)
Expand Down
216 changes: 216 additions & 0 deletions src/aleph/sdk/client/vm_confidential_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import base64
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import aiohttp
from aleph_message.models import ItemHash

from aleph.sdk.client.vm_client import VmClient
from aleph.sdk.types import Account, SEVMeasurement
from aleph.sdk.utils import (
compute_confidential_measure,
encrypt_secret_table,
get_vm_measure,
make_packet_header,
make_secret_table,
run_in_subprocess,
)

logger = logging.getLogger(__name__)


class VmConfidentialClient(VmClient):
sevctl_path: Path

def __init__(
self,
account: Account,
sevctl_path: Path,
node_url: str = "",
session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(account, node_url, session)
self.sevctl_path = sevctl_path

async def get_certificates(self) -> Tuple[Optional[int], str]:
"""
Get platform confidential certificate
"""

url = f"{self.node_url}/about/certificates"
try:
async with self.session.get(url) as response:
data = await response.read()
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
tmp_file.write(data)
return response.status, tmp_file.name

except aiohttp.ClientError as e:
logger.error(
f"HTTP error getting node certificates on {self.node_url}: {str(e)}"
)
return None, str(e)

async def create_session(
self, vm_id: ItemHash, certificate_path: Path, policy: int
) -> Path:
"""
Create new confidential session
"""

current_path = Path().cwd()
args = [
"session",
"--name",
str(vm_id),
str(certificate_path),
str(policy),
]
try:
# TODO: Check command result
await self.sevctl_cmd(*args)
return current_path
except Exception as e:
raise ValueError(f"Session creation have failed, reason: {str(e)}")

async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str:
"""
Initialize Confidential VM negociation passing the needed session files
"""

session_file = session.read_bytes()
godh_file = godh.read_bytes()
params = {
"session": session_file,
"godh": godh_file,
}
return await self.perform_confidential_operation(
vm_id, "confidential/initialize", params=params
)

async def measurement(self, vm_id: ItemHash) -> SEVMeasurement:
"""
Fetch VM confidential measurement
"""

if not self.pubkey_signature_header:
self.pubkey_signature_header = (
await self._generate_pubkey_signature_header()
)

status, text = await self.perform_operation(
vm_id, "confidential/measurement", method="GET"
)
sev_mesurement = SEVMeasurement.parse_raw(text)
return sev_mesurement

async def validate_measure(
self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str
) -> bool:
"""
Validate VM confidential measurement
"""

tik = tik_path.read_bytes()
vm_measure, nonce = get_vm_measure(sev_data)

expected_measure = compute_confidential_measure(
sev_info=sev_data.sev_info,
tik=tik,
expected_hash=firmware_hash,
nonce=nonce,
).digest()
return expected_measure == vm_measure

async def build_secret(
self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str
) -> Tuple[str, str]:
"""
Build disk secret to be injected on the confidential VM
"""

tek = tek_path.read_bytes()
tik = tik_path.read_bytes()

vm_measure, _ = get_vm_measure(sev_data)

iv = os.urandom(16)
secret_table = make_secret_table(secret)
encrypted_secret_table = encrypt_secret_table(
secret_table=secret_table, tek=tek, iv=iv
)

packet_header = make_packet_header(
vm_measure=vm_measure,
encrypted_secret_table=encrypted_secret_table,
secret_table_size=len(secret_table),
tik=tik,
iv=iv,
)

encoded_packet_header = base64.b64encode(packet_header).decode()
encoded_secret = base64.b64encode(encrypted_secret_table).decode()

return encoded_packet_header, encoded_secret

async def inject_secret(
self, vm_id: ItemHash, packet_header: str, secret: str
) -> Dict:
"""
Send the secret by the encrypted channel to boot up the VM
"""

params = {
"packet_header": packet_header,
"secret": secret,
}
text = await self.perform_confidential_operation(
vm_id, "confidential/inject_secret", json=params
)

return json.loads(text)

async def perform_confidential_operation(
self,
vm_id: ItemHash,
operation: str,
params: Optional[Dict[str, Any]] = None,
json=None,
) -> str:
"""
Send confidential operations to the CRN passing the auth headers on each request
"""

if not self.pubkey_signature_header:
self.pubkey_signature_header = (
await self._generate_pubkey_signature_header()
)

url, header = await self._generate_header(
vm_id=vm_id, operation=operation, method="post"
)

try:
async with self.session.post(
url, headers=header, data=params, json=json
) as response:
response.raise_for_status()
response_text = await response.text()
return response_text

except aiohttp.ClientError as e:
raise ValueError(f"HTTP error during operation {operation}: {str(e)}")

async def sevctl_cmd(self, *args) -> bytes:
"""
Execute `sevctl` command with given arguments
"""

return await run_in_subprocess(
[str(self.sevctl_path), *args],
check=True,
)
25 changes: 25 additions & 0 deletions src/aleph/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import Enum
from typing import Dict, Protocol, TypeVar

from pydantic import BaseModel

__all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage")

from aleph_message.models import AlephMessage
Expand Down Expand Up @@ -39,3 +41,26 @@ async def sign_raw(self, buffer: bytes) -> bytes: ...


GenericMessage = TypeVar("GenericMessage", bound=AlephMessage)


class SEVInfo(BaseModel):
"""
An AMD SEV platform information.
"""

enabled: bool
api_major: int
api_minor: int
build_id: int
policy: int
state: str
handle: int


class SEVMeasurement(BaseModel):
"""
A SEV measurement data get from Qemu measurement.
"""

sev_info: SEVInfo
launch_measure: str
Loading

0 comments on commit 247dbfc

Please sign in to comment.