diff --git a/src/aleph/vm/controllers/interface.py b/src/aleph/vm/controllers/interface.py index efce726e2..d5a290173 100644 --- a/src/aleph/vm/controllers/interface.py +++ b/src/aleph/vm/controllers/interface.py @@ -10,7 +10,7 @@ from aleph.vm.controllers.firecracker.snapshots import CompressedDiskVolumeSnapshot from aleph.vm.network.interfaces import TapInterface -from aleph.vm.utils.logs import make_logs_queue +from aleph.vm.utils.logs import get_past_vm_logs, make_logs_queue logger = logging.getLogger(__name__) @@ -118,3 +118,6 @@ def _journal_stdout_name(self) -> str: @property def _journal_stderr_name(self) -> str: return f"vm-{self.vm_hash}-stderr" + + def past_logs(self): + yield from get_past_vm_logs(self._journal_stdout_name, self._journal_stderr_name) diff --git a/src/aleph/vm/controllers/qemu/instance.py b/src/aleph/vm/controllers/qemu/instance.py index 2353162b6..81dea8ff3 100644 --- a/src/aleph/vm/controllers/qemu/instance.py +++ b/src/aleph/vm/controllers/qemu/instance.py @@ -247,7 +247,3 @@ async def teardown(self): if self.tap_interface: await self.tap_interface.delete() await self.stop_guest_api() - - def print_logs(self) -> None: - """Print logs to our output for debugging""" - queue = self.get_log_queue() diff --git a/src/aleph/vm/orchestrator/supervisor.py b/src/aleph/vm/orchestrator/supervisor.py index 31f7fe42f..a3905cae2 100644 --- a/src/aleph/vm/orchestrator/supervisor.py +++ b/src/aleph/vm/orchestrator/supervisor.py @@ -49,6 +49,7 @@ from .views.operator import ( operate_erase, operate_expire, + operate_logs, operate_reboot, operate_stop, stream_logs, @@ -100,7 +101,8 @@ def setup_webapp(): web.get("/about/config", about_config), # /control APIs are used to control the VMs and access their logs web.post("/control/allocation/notify", notify_allocation), - web.get("/control/machine/{ref}/logs", stream_logs), + web.get("/control/machine/{ref}/stream_logs", stream_logs), + web.get("/control/machine/{ref}/logs", operate_logs), web.post("/control/machine/{ref}/expire", operate_expire), web.post("/control/machine/{ref}/stop", operate_stop), web.post("/control/machine/{ref}/erase", operate_erase), diff --git a/src/aleph/vm/orchestrator/views/operator.py b/src/aleph/vm/orchestrator/views/operator.py index 148ecd092..cd8fbae14 100644 --- a/src/aleph/vm/orchestrator/views/operator.py +++ b/src/aleph/vm/orchestrator/views/operator.py @@ -65,6 +65,7 @@ async def stream_logs(request: web.Request) -> web.StreamResponse: queue = None try: ws = web.WebSocketResponse() + logger.info(f"starting websocket: {request.path}") await ws.prepare(request) try: await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws) @@ -75,6 +76,7 @@ async def stream_logs(request: web.Request) -> web.StreamResponse: while True: log_type, message = await queue.get() assert log_type in ("stdout", "stderr") + logger.debug(message) await ws.send_json({"type": log_type, "message": message}) @@ -87,15 +89,41 @@ async def stream_logs(request: web.Request) -> web.StreamResponse: execution.vm.unregister_queue(queue) +@cors_allow_all +@require_jwk_authentication +async def operate_logs(request: web.Request, authenticated_sender: str) -> web.StreamResponse: + """Logs of a VM (not streaming)""" + vm_hash = get_itemhash_or_400(request.match_info) + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") + + response = web.StreamResponse() + response.headers["Content-Type"] = "text/plain" + await response.prepare(request) + + for entry in execution.vm.past_logs(): + msg = f'{entry["__REALTIME_TIMESTAMP"].isoformat()}> {entry["MESSAGE"]}' + await response.write(msg.encode()) + await response.write_eof() + return response + + async def authenticate_websocket_for_vm_or_403(execution: VmExecution, vm_hash: ItemHash, ws: web.WebSocketResponse): """Authenticate a websocket connection. Web browsers do not allow setting headers in WebSocket requests, so the authentication relies on the first message sent by the client. """ - first_message = await ws.receive_json() + try: + first_message = await ws.receive_json() + except TypeError as error: + logging.exception(error) + raise web.HTTPForbidden(body="Invalid auth package") credentials = first_message["auth"] authenticated_sender = await authenticate_websocket_message(credentials) + if is_sender_authorized(authenticated_sender, execution.message): logger.debug(f"Accepted request to access logs by {authenticated_sender} on {vm_hash}") return True diff --git a/src/aleph/vm/utils/logs.py b/src/aleph/vm/utils/logs.py index d95adbac6..a112cfc0a 100644 --- a/src/aleph/vm/utils/logs.py +++ b/src/aleph/vm/utils/logs.py @@ -1,6 +1,7 @@ import asyncio import logging -from typing import Callable, TypedDict +from datetime import datetime +from typing import Callable, Generator, TypedDict from systemd import journal @@ -10,6 +11,7 @@ class EntryDict(TypedDict): SYSLOG_IDENTIFIER: str MESSAGE: str + __REALTIME_TIMESTAMP: datetime def make_logs_queue(stdout_identifier, stderr_identifier, skip_past=False) -> tuple[asyncio.Queue, Callable[[], None]]: @@ -56,3 +58,25 @@ def do_cancel(): r.close() return queue, do_cancel + + +def get_past_vm_logs(stdout_identifier, stderr_identifier) -> Generator[EntryDict, None, None]: + """Get existing log for the VM identifiers. + + @param stdout_identifier: journald identifier for process stdout + @param stderr_identifier: journald identifier for process stderr + @return: an iterator of log entry + + Works by creating a journald reader, and using `add_reader` to call a callback when + data is available for reading. + + For more information refer to the sd-journal(3) manpage + and systemd.journal module documentation. + """ + r = journal.Reader() + r.add_match(SYSLOG_IDENTIFIER=stdout_identifier) + r.add_match(SYSLOG_IDENTIFIER=stderr_identifier) + + r.seek_head() + for entry in r: + yield entry diff --git a/src/aleph/vm/utils/test_helpers.py b/src/aleph/vm/utils/test_helpers.py new file mode 100644 index 000000000..ecdf4f40b --- /dev/null +++ b/src/aleph/vm/utils/test_helpers.py @@ -0,0 +1,86 @@ +import datetime +import json + +import eth_account.messages +import pytest +from eth_account.datastructures import SignedMessage +from eth_account.signers.local import LocalAccount +from jwcrypto import jwk +from jwcrypto.jwa import JWA + + +@pytest.fixture +def patch_datetime_now(monkeypatch): + """Fixture for patching the datetime.now() and datetime.utcnow() methods + to return a fixed datetime object. + This fixture creates a subclass of `datetime.datetime` called `mydatetime`, + which overrides the `now()` and `utcnow()` class methods to return a fixed + datetime object specified by `FAKE_TIME`. + """ + + class MockDateTime(datetime.datetime): + FAKE_TIME = datetime.datetime(2010, 12, 25, 17, 5, 55) + + @classmethod + def now(cls, tz=None, *args, **kwargs): + return cls.FAKE_TIME.replace(tzinfo=tz) + + @classmethod + def utcnow(cls, *args, **kwargs): + return cls.FAKE_TIME + + monkeypatch.setattr(datetime, "datetime", MockDateTime) + return MockDateTime + + +async def generate_signer_and_signed_headers_for_operation( + patch_datetime_now, operation_payload: dict +) -> tuple[LocalAccount, dict]: + """Generate a temporary eth_account for testing and sign the operation with it""" + account = eth_account.Account() + signer_account = account.create() + key = jwk.JWK.generate( + kty="EC", + crv="P-256", + # key_ops=["verify"], + ) + pubkey = { + "pubkey": json.loads(key.export_public()), + "alg": "ECDSA", + "domain": "localhost", + "address": signer_account.address, + "expires": (patch_datetime_now.FAKE_TIME + datetime.timedelta(days=1)).isoformat() + "Z", + } + pubkey_payload = json.dumps(pubkey).encode("utf-8").hex() + signable_message = eth_account.messages.encode_defunct(hexstr=pubkey_payload) + signed_message: SignedMessage = signer_account.sign_message(signable_message) + pubkey_signature = to_0x_hex(signed_message.signature) + pubkey_signature_header = json.dumps( + { + "payload": pubkey_payload, + "signature": pubkey_signature, + } + ) + payload_as_bytes = json.dumps(operation_payload).encode("utf-8") + + payload_signature = JWA.signing_alg("ES256").sign(key, payload_as_bytes) + headers = { + "X-SignedPubKey": pubkey_signature_header, + "X-SignedOperation": json.dumps( + { + "payload": payload_as_bytes.hex(), + "signature": payload_signature.hex(), + } + ), + } + return signer_account, headers + + +def to_0x_hex(b: bytes) -> str: + """ + Convert the bytes to a 0x-prefixed hex string + """ + + # force this for compat between different hexbytes versions which behave differenty + # and conflict with other package don't allow us to have the version we want + return "0x" + bytes.hex(b) diff --git a/tests/supervisor/test_authentication.py b/tests/supervisor/test_authentication.py index 249806f01..f4269a4ad 100644 --- a/tests/supervisor/test_authentication.py +++ b/tests/supervisor/test_authentication.py @@ -1,4 +1,3 @@ -import datetime import json from typing import Any @@ -8,22 +7,16 @@ from eth_account.datastructures import SignedMessage from jwcrypto import jwk, jws from jwcrypto.common import base64url_decode -from jwcrypto.jwa import JWA from aleph.vm.orchestrator.views.authentication import ( authenticate_jwk, require_jwk_authentication, ) - - -def to_0x_hex(b: bytes) -> str: - """ - Convert the bytes to a 0x-prefixed hex string - """ - - # force this for compat between different hexbytes versions which behave differenty - # and conflict with other package don't allow us to have the version we want - return "0x" + bytes.hex(b) +from aleph.vm.utils.test_helpers import ( + generate_signer_and_signed_headers_for_operation, + patch_datetime_now, + to_0x_hex, +) @pytest.mark.asyncio @@ -67,30 +60,6 @@ async def view(request, authenticated_sender): assert {"error": "Invalid X-SignedPubKey format"} == r -@pytest.fixture -def patch_datetime_now(monkeypatch): - """Fixture for patching the datetime.now() and datetime.utcnow() methods - to return a fixed datetime object. - This fixture creates a subclass of `datetime.datetime` called `mydatetime`, - which overrides the `now()` and `utcnow()` class methods to return a fixed - datetime object specified by `FAKE_TIME`. - """ - - class MockDateTime(datetime.datetime): - FAKE_TIME = datetime.datetime(2010, 12, 25, 17, 5, 55) - - @classmethod - def now(cls, tz=None, *args, **kwargs): - return cls.FAKE_TIME.replace(tzinfo=tz) - - @classmethod - def utcnow(cls, *args, **kwargs): - return cls.FAKE_TIME - - monkeypatch.setattr(datetime, "datetime", MockDateTime) - return MockDateTime - - @pytest.mark.asyncio async def test_require_jwk_authentication_expired(aiohttp_client): app = web.Application() @@ -257,32 +226,8 @@ async def test_require_jwk_authentication_good_key(aiohttp_client, patch_datetim """An HTTP request to a view decorated by `@require_jwk_authentication` auth correctly a temporary key signed by a wallet and an operation signed by that key""" app = web.Application() - - account = eth_account.Account() - signer_account = account.create() - key = jwk.JWK.generate( - kty="EC", - crv="P-256", - # key_ops=["verify"], - ) - - pubkey = { - "pubkey": json.loads(key.export_public()), - "alg": "ECDSA", - "domain": "localhost", - "address": signer_account.address, - "expires": (patch_datetime_now.FAKE_TIME + datetime.timedelta(days=1)).isoformat() + "Z", - } - pubkey_payload = json.dumps(pubkey).encode("utf-8").hex() - signable_message = eth_account.messages.encode_defunct(hexstr=pubkey_payload) - signed_message: SignedMessage = signer_account.sign_message(signable_message) - pubkey_signature = to_0x_hex(signed_message.signature) - pubkey_signature_header = json.dumps( - { - "payload": pubkey_payload, - "signature": pubkey_signature, - } - ) + payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/"} + signer_account, headers = await generate_signer_and_signed_headers_for_operation(patch_datetime_now, payload) @require_jwk_authentication async def view(request, authenticated_sender): @@ -292,18 +237,6 @@ async def view(request, authenticated_sender): app.router.add_get("", view) client = await aiohttp_client(app) - payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/"} - - payload_as_bytes = json.dumps(payload).encode("utf-8") - headers = {"X-SignedPubKey": pubkey_signature_header} - payload_signature = JWA.signing_alg("ES256").sign(key, payload_as_bytes) - headers["X-SignedOperation"] = json.dumps( - { - "payload": payload_as_bytes.hex(), - "signature": payload_signature.hex(), - } - ) - resp = await client.get("/", headers=headers) assert resp.status == 200, await resp.text() diff --git a/tests/supervisor/views/test_operator.py b/tests/supervisor/views/test_operator.py index 3e50fe53e..72a42ae09 100644 --- a/tests/supervisor/views/test_operator.py +++ b/tests/supervisor/views/test_operator.py @@ -1,6 +1,19 @@ +import asyncio +import datetime +import json +from asyncio import Queue + +import aiohttp import pytest +from aiohttp.test_utils import TestClient from aleph.vm.orchestrator.supervisor import setup_webapp +from aleph.vm.pool import VmPool +from aleph.vm.utils.logs import EntryDict +from aleph.vm.utils.test_helpers import ( + generate_signer_and_signed_headers_for_operation, + patch_datetime_now, +) @pytest.mark.asyncio @@ -36,3 +49,226 @@ class FakeVmPool: await response.text() == "Rebooted VM with ref fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_" ) assert pool.systemd_manager.restart.call_count == 1 + + +@pytest.mark.asyncio +async def test_logs(aiohttp_client, mocker): + mock_address = "mock_address" + mock_hash = "fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_" + mocker.patch( + "aleph.vm.orchestrator.views.authentication.authenticate_jwk", + return_value=mock_address, + ) + + # noinspection PyMissingConstructor + class FakeVmPool(VmPool): + def __init__(self): + pass + + executions = { + mock_hash: mocker.Mock( + vm_hash=mock_hash, + message=mocker.Mock(address=mock_address), + is_confidential=False, + is_running=True, + vm=mocker.Mock( + past_logs=mocker.Mock( + return_value=[ + EntryDict( + SYSLOG_IDENTIFIER="stdout", + MESSAGE="logline1", + __REALTIME_TIMESTAMP=datetime.datetime(2020, 10, 12, 1, 2), + ), + EntryDict( + SYSLOG_IDENTIFIER="stdout", + MESSAGE="logline2", + __REALTIME_TIMESTAMP=datetime.datetime(2020, 10, 12, 1, 3), + ), + ] + ) + ), + ), + } + systemd_manager = mocker.Mock(restart=mocker.Mock()) + + app = setup_webapp() + pool = FakeVmPool() + app["vm_pool"] = pool + app["pubsub"] = FakeVmPool() + client = await aiohttp_client(app) + response = await client.get( + f"/control/machine/{mock_hash}/logs", + ) + assert response.status == 200 + assert await response.text() == "2020-10-12T01:02:00> logline12020-10-12T01:03:00> logline2" + + +@pytest.mark.asyncio +async def test_websocket_logs(aiohttp_client, mocker): + mock_address = "mock_address" + mock_hash = "fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_" + mocker.patch( + "aleph.vm.orchestrator.views.operator.authenticate_websocket_message", + return_value=mock_address, + ) + fake_queue: Queue[tuple[str, str]] = asyncio.Queue() + await fake_queue.put(("stdout", "this is a first log entry")) + + fakeVmPool = mocker.Mock( + executions={ + mock_hash: mocker.Mock( + vm_hash=mock_hash, + message=mocker.Mock(address=mock_address), + is_confidential=False, + is_running=True, + vm=mocker.Mock( + get_log_queue=mocker.Mock(return_value=fake_queue), + ), + ), + }, + ) + app = setup_webapp() + app["vm_pool"] = fakeVmPool + app["pubsub"] = None + client = await aiohttp_client(app) + websocket = await client.ws_connect( + f"/control/machine/{mock_hash}/stream_logs", + ) + await websocket.send_json({"auth": "auth is disabled"}) + response = await websocket.receive_json() + assert response == {"status": "connected"} + + response = await websocket.receive_json() + assert response == {"message": "this is a first log entry", "type": "stdout"} + + await fake_queue.put(("stdout", "this is a second log entry")) + response = await websocket.receive_json() + assert response == {"message": "this is a second log entry", "type": "stdout"} + await websocket.close() + assert websocket.closed + + +@pytest.mark.asyncio +async def test_websocket_logs_missing_auth(aiohttp_client, mocker): + mock_address = "mock_address" + mock_hash = "fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_" + + fake_queue: Queue[tuple[str, str]] = asyncio.Queue() + await fake_queue.put(("stdout", "this is a first log entry")) + + fakeVmPool = mocker.Mock( + executions={ + mock_hash: mocker.Mock( + vm_hash=mock_hash, + message=mocker.Mock(address=mock_address), + is_confidential=False, + is_running=True, + vm=mocker.Mock( + get_log_queue=mocker.Mock(return_value=fake_queue), + ), + ), + }, + ) + app = setup_webapp() + app["vm_pool"] = fakeVmPool + app["pubsub"] = None + client = await aiohttp_client(app) + websocket = await client.ws_connect( + f"/control/machine/{mock_hash}/stream_logs", + ) + # Wait for message without sending an auth package. + # Test with a timeout because we receive nothing + with pytest.raises((TimeoutError, asyncio.exceptions.TimeoutError)): + response = await websocket.receive_json(timeout=1) + assert False + + # It's totally reachable with the pytest.raises + # noinspection PyUnreachableCode + await websocket.close() + assert websocket.closed + + +@pytest.mark.asyncio +async def test_websocket_logs_invalid_auth(aiohttp_client, mocker): + mock_address = "mock_address" + mock_hash = "fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_" + + fake_queue: Queue[tuple[str, str]] = asyncio.Queue() + await fake_queue.put(("stdout", "this is a first log entry")) + + fakeVmPool = mocker.Mock( + executions={ + mock_hash: mocker.Mock( + vm_hash=mock_hash, + message=mocker.Mock(address=mock_address), + is_confidential=False, + is_running=True, + vm=mocker.Mock( + get_log_queue=mocker.Mock(return_value=fake_queue), + ), + ), + }, + ) + app = setup_webapp() + app["vm_pool"] = fakeVmPool + app["pubsub"] = None + client: TestClient = await aiohttp_client(app) + websocket = await client.ws_connect( + f"/control/machine/{mock_hash}/stream_logs", + ) + + await websocket.send_json({"auth": "invalid auth package"}) + response = await websocket.receive() + # Subject to change in the future, for now the connexion si broken and closed + assert response.type == aiohttp.WSMsgType.CLOSE + assert websocket.closed + + +@pytest.mark.asyncio +async def test_websocket_logs_good_auth(aiohttp_client, mocker, patch_datetime_now): + "Test valid authentification for websocket logs endpoint" + payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/"} + signer_account, headers = await generate_signer_and_signed_headers_for_operation(patch_datetime_now, payload) + + mock_address = signer_account.address + mock_hash = "fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_fake_vm_" + + fake_queue: Queue[tuple[str, str]] = asyncio.Queue() + await fake_queue.put(("stdout", "this is a first log entry")) + + fakeVmPool = mocker.Mock( + executions={ + mock_hash: mocker.Mock( + vm_hash=mock_hash, + message=mocker.Mock(address=mock_address), + is_confidential=False, + is_running=True, + vm=mocker.Mock( + get_log_queue=mocker.Mock(return_value=fake_queue), + ), + ), + }, + ) + app = setup_webapp() + app["vm_pool"] = fakeVmPool + app["pubsub"] = None + client = await aiohttp_client(app) + websocket = await client.ws_connect( + f"/control/machine/{mock_hash}/stream_logs", + ) + # Need to deserialize since we pass a json otherwhise it get double json encoded + # which is not what the endpoint expect + auth_package = { + "X-SignedPubKey": json.loads(headers["X-SignedPubKey"]), + "X-SignedOperation": json.loads(headers["X-SignedOperation"]), + } + + await websocket.send_json({"auth": auth_package}) + response = await websocket.receive_json() + assert response == {"status": "connected"} + + response = await websocket.receive_json() + assert response == {"message": "this is a first log entry", "type": "stdout"} + + await websocket.close() + assert websocket.closed