diff --git a/pyproject.toml b/pyproject.toml index 50c314c2..d33f05a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,15 +191,16 @@ lint.ignore = [ # Allow the use of assert statements "S101", ] -# Tests can use magic values, assertions, and relative imports -lint.per-file-ignores."tests/**/*" = [ "PLR2004", "S101", "TID252" ] #[tool.ruff.flake8-tidy-imports] #ban-relative-imports = "all" #unfixable = [ # # Don't touch unused imports # "F401", #] -lint.isort = [ "aleph.vm" ] +#lint.isort = [ "aleph.vm" ] + +# Tests can use magic values, assertions, and relative imports +lint.per-file-ignores."tests/**/*" = [ "PLR2004", "S101", "TID252" ] [tool.pytest.ini_options] pythonpath = [ diff --git a/src/aleph/vm/conf.py b/src/aleph/vm/conf.py index b68ff9e8..739cfda9 100644 --- a/src/aleph/vm/conf.py +++ b/src/aleph/vm/conf.py @@ -136,6 +136,7 @@ class Settings(BaseSettings): # System logs make boot ~2x slower PRINT_SYSTEM_LOGS = False IGNORE_TRACEBACK_FROM_DIAGNOSTICS = True + LOG_LEVEL = "WARNING" DEBUG_ASYNCIO = False # Networking does not work inside Docker/Podman @@ -396,8 +397,6 @@ def setup(self): STREAM_CHAINS[Chain.AVAX].rpc = str(self.RPC_AVAX) STREAM_CHAINS[Chain.BASE].rpc = str(self.RPC_BASE) - logger.info(STREAM_CHAINS) - os.makedirs(self.MESSAGE_CACHE, exist_ok=True) os.makedirs(self.CODE_CACHE, exist_ok=True) os.makedirs(self.RUNTIME_CACHE, exist_ok=True) diff --git a/src/aleph/vm/orchestrator/cli.py b/src/aleph/vm/orchestrator/cli.py index ddcf8910..bbae396d 100644 --- a/src/aleph/vm/orchestrator/cli.py +++ b/src/aleph/vm/orchestrator/cli.py @@ -23,6 +23,7 @@ from aleph.vm.version import __version__, get_version_from_apt, get_version_from_git from . import metrics, supervisor +from .custom_logs import setup_handlers from .pubsub import PubSub from .run import run_code_on_event, run_code_on_request, start_persistent_vm @@ -65,7 +66,7 @@ def parse_args(args): help="set loglevel to INFO", action="store_const", const=logging.INFO, - default=logging.WARNING, + default=settings.LOG_LEVEL, ) parser.add_argument( "-vv", @@ -282,7 +283,7 @@ def run_db_migrations(connection): async def run_async_db_migrations(): - async_engine = create_async_engine(make_db_url(), echo=True) + async_engine = create_async_engine(make_db_url(), echo=False) async with async_engine.begin() as conn: await conn.run_sync(run_db_migrations) @@ -293,13 +294,20 @@ def main(): log_format = ( "%(relativeCreated)4f | %(levelname)s | %(message)s" if args.profile - else "%(asctime)s | %(levelname)s | %(message)s" + else "%(asctime)s | %(levelname)s %(name)s:%(lineno)s | %(message)s" ) + # log_format = "[%(asctime)s] p%(process)s {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s" + + handlers = setup_handlers(args, log_format) logging.basicConfig( level=args.loglevel, format=log_format, + handlers=handlers, ) + logging.getLogger("aiosqlite").setLevel(settings.LOG_LEVEL) + logging.getLogger("sqlalchemy.engine").setLevel(settings.LOG_LEVEL) + settings.update( USE_JAILER=args.use_jailer, PRINT_SYSTEM_LOGS=args.system_logs, diff --git a/src/aleph/vm/orchestrator/custom_logs.py b/src/aleph/vm/orchestrator/custom_logs.py new file mode 100644 index 00000000..9150fdd7 --- /dev/null +++ b/src/aleph/vm/orchestrator/custom_logs.py @@ -0,0 +1,54 @@ +import contextlib +import logging +from contextvars import ContextVar + +from aleph_message.models import ItemHash + +from aleph.vm.models import VmExecution + +ctx_current_execution: ContextVar[VmExecution | None] = ContextVar("current_execution") +ctx_current_execution_hash: 
ContextVar[ItemHash | None] = ContextVar("current_execution_hash") + + +@contextlib.contextmanager +def set_vm_for_logging(vm_hash): + token = ctx_current_execution_hash.set(vm_hash) + try: + yield + finally: + ctx_current_execution_hash.reset(token) + + +class InjectingFilter(logging.Filter): + """ + A filter which injects context-specific information into logs + """ + + def filter(self, record): + + vm_hash = ctx_current_execution_hash.get(None) + if not vm_hash: + vm_execution: VmExecution | None = ctx_current_execution.get(None) + if vm_execution: + vm_hash = vm_execution.vm_hash + + if not vm_hash: + return False + + record.vm_hash = vm_hash + return True + + +def setup_handlers(args, log_format): + # Set up two custom handlers: one adds the VM hash to records emitted inside an execution context, the other prints records logged outside any VM context + execution_handler = logging.StreamHandler() + execution_handler.addFilter(InjectingFilter()) + execution_handler.setFormatter( + logging.Formatter("%(asctime)s | %(levelname)s %(name)s:%(lineno)s | {%(vm_hash)s} %(message)s ") + ) + non_execution_handler = logging.StreamHandler() + non_execution_handler.addFilter(lambda x: ctx_current_execution_hash.get(None) is None) + non_execution_handler.setFormatter( + logging.Formatter("%(asctime)s | %(levelname)s %(name)s:%(lineno)s | %(message)s ") + ) + return [non_execution_handler, execution_handler] diff --git a/src/aleph/vm/orchestrator/metrics.py b/src/aleph/vm/orchestrator/metrics.py index 3b8cdf9f..67222521 100644 --- a/src/aleph/vm/orchestrator/metrics.py +++ b/src/aleph/vm/orchestrator/metrics.py @@ -38,7 +38,7 @@ def setup_engine(): global AsyncSessionMaker - engine = create_async_engine(make_db_url(), echo=True) + engine = create_async_engine(make_db_url(), echo=False) AsyncSessionMaker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) return engine diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index 4bba01aa..4c9b3866 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -1,4 +1,5 @@ import binascii +import contextlib import logging from decimal import Decimal from hashlib import sha256 @@ -25,6 +26,7 @@ from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError from aleph.vm.orchestrator import payment, status from aleph.vm.orchestrator.chain import STREAM_CHAINS, ChainInfo +from aleph.vm.orchestrator.custom_logs import set_vm_for_logging from aleph.vm.orchestrator.messages import try_get_message from aleph.vm.orchestrator.metrics import get_execution_records from aleph.vm.orchestrator.payment import ( @@ -75,7 +77,8 @@ async def run_code_from_path(request: web.Request) -> web.Response: ) from e pool: VmPool = request.app["vm_pool"] - return await run_code_on_request(message_ref, path, pool, request) + with set_vm_for_logging(vm_hash=message_ref): + return await run_code_on_request(message_ref, path, pool, request) async def run_code_from_hostname(request: web.Request) -> web.Response: @@ -112,7 +115,8 @@ async def run_code_from_hostname(request: web.Request) -> web.Response: return HTTPNotFound(reason="Invalid message reference") pool = request.app["vm_pool"] - return await run_code_on_request(message_ref, path, pool, request) + with set_vm_for_logging(vm_hash=message_ref): + return await run_code_on_request(message_ref, path, pool, request) def authenticate_request(request: web.Request) -> None: diff --git a/src/aleph/vm/orchestrator/views/operator.py 
b/src/aleph/vm/orchestrator/views/operator.py index af0e98f4..72218f3e 100644 --- a/src/aleph/vm/orchestrator/views/operator.py +++ b/src/aleph/vm/orchestrator/views/operator.py @@ -15,6 +15,7 @@ from aleph.vm.conf import settings from aleph.vm.controllers.qemu.client import QemuVmClient from aleph.vm.models import VmExecution +from aleph.vm.orchestrator.custom_logs import set_vm_for_logging from aleph.vm.orchestrator.run import create_vm_execution_or_raise_http_error from aleph.vm.orchestrator.views.authentication import ( authenticate_websocket_message, @@ -63,36 +64,37 @@ async def stream_logs(request: web.Request) -> web.StreamResponse: allow Javascript to set headers in WebSocket requests. """ vm_hash = get_itemhash_or_400(request.match_info) - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) - if execution.vm is None: - raise web.HTTPBadRequest(body=f"VM {vm_hash} is not running") - queue = None - try: - ws = web.WebSocketResponse() - logger.info(f"starting websocket: {request.path}") - await ws.prepare(request) + if execution.vm is None: + raise web.HTTPBadRequest(body=f"VM {vm_hash} is not running") + queue = None try: - await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws) - await ws.send_json({"status": "connected"}) + ws = web.WebSocketResponse() + logger.info(f"starting websocket: {request.path}") + await ws.prepare(request) + try: + await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws) + await ws.send_json({"status": "connected"}) - queue = execution.vm.get_log_queue() + queue = execution.vm.get_log_queue() - while True: - log_type, message = await queue.get() - assert log_type in ("stdout", "stderr") - logger.debug(message) + while True: + log_type, message = await queue.get() + assert log_type in ("stdout", "stderr") + logger.debug(message) - await ws.send_json({"type": log_type, "message": message}) + await ws.send_json({"type": log_type, "message": message}) - finally: - await ws.close() - logger.info(f"connection {ws} closed") + finally: + await ws.close() + logger.info(f"connection {ws} closed") - finally: - if queue: - execution.vm.unregister_queue(queue) + finally: + if queue: + execution.vm.unregister_queue(queue) @cors_allow_all @@ -100,20 +102,21 @@ async def stream_logs(request: web.Request) -> web.StreamResponse: async def operate_logs(request: web.Request, authenticated_sender: str) -> web.StreamResponse: """Logs of a VM (not streaming)""" vm_hash = get_itemhash_or_400(request.match_info) - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - response = web.StreamResponse() - response.headers["Content-Type"] = "text/plain" - await response.prepare(request) + response = web.StreamResponse() + response.headers["Content-Type"] = "text/plain" + await response.prepare(request) - for entry in execution.vm.past_logs(): - msg = f'{entry["__REALTIME_TIMESTAMP"].isoformat()}> {entry["MESSAGE"]}' - await 
response.write(msg.encode()) - await response.write_eof() - return response + for entry in execution.vm.past_logs(): + msg = f'{entry["__REALTIME_TIMESTAMP"].isoformat()}> {entry["MESSAGE"]}' + await response.write(msg.encode()) + await response.write_eof() + return response async def authenticate_websocket_for_vm_or_403(execution: VmExecution, vm_hash: ItemHash, ws: web.WebSocketResponse): @@ -154,24 +157,25 @@ async def operate_expire(request: web.Request, authenticated_sender: str) -> web A timeout may be specified to delay the action.""" vm_hash = get_itemhash_or_400(request.match_info) - try: - timeout = float(ItemHash(request.match_info["timeout"])) - except (KeyError, ValueError) as error: - raise web.HTTPBadRequest(body="Invalid timeout duration") from error - if not 0 < timeout < timedelta(days=10).total_seconds(): - return web.HTTPBadRequest(body="Invalid timeout duration") + with set_vm_for_logging(vm_hash=vm_hash): + try: + timeout = float(ItemHash(request.match_info["timeout"])) + except (KeyError, ValueError) as error: + raise web.HTTPBadRequest(body="Invalid timeout duration") from error + if not 0 < timeout < timedelta(days=10).total_seconds(): + return web.HTTPBadRequest(body="Invalid timeout duration") - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - logger.info(f"Expiring in {timeout} seconds: {execution.vm_hash}") - await execution.expire(timeout=timeout) - execution.persistent = False + logger.info(f"Expiring in {timeout} seconds: {execution.vm_hash}") + await execution.expire(timeout=timeout) + execution.persistent = False - return web.Response(status=200, body=f"Expiring VM with ref {vm_hash} in {timeout} seconds") + return web.Response(status=200, body=f"Expiring VM with ref {vm_hash} in {timeout} seconds") @cors_allow_all @@ -179,53 +183,54 @@ async def operate_expire(request: web.Request, authenticated_sender: str) -> web async def operate_confidential_initialize(request: web.Request, authenticated_sender: str) -> web.Response: """Start the confidential virtual machine if possible.""" vm_hash = get_itemhash_or_400(request.match_info) + with set_vm_for_logging(vm_hash=vm_hash): - pool: VmPool = request.app["vm_pool"] - logger.debug(f"Iterating through running executions... {pool.executions}") - execution = get_execution_or_404(vm_hash, pool=pool) + pool: VmPool = request.app["vm_pool"] + logger.debug(f"Iterating through running executions... 
{pool.executions}") + execution = get_execution_or_404(vm_hash, pool=pool) - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - if execution.is_running: - return web.json_response( - {"code": "vm_running", "description": "Operation not allowed, instance already running"}, - status=HTTPStatus.BAD_REQUEST, - ) - if not execution.is_confidential: - return web.json_response( - {"code": "not_confidential", "description": "Instance is not a confidential instance"}, - status=HTTPStatus.BAD_REQUEST, - ) + if execution.is_running: + return web.json_response( + {"code": "vm_running", "description": "Operation not allowed, instance already running"}, + status=HTTPStatus.BAD_REQUEST, + ) + if not execution.is_confidential: + return web.json_response( + {"code": "not_confidential", "description": "Instance is not a confidential instance"}, + status=HTTPStatus.BAD_REQUEST, + ) - post = await request.post() + post = await request.post() - vm_session_path = settings.CONFIDENTIAL_SESSION_DIRECTORY / vm_hash - vm_session_path.mkdir(exist_ok=True) + vm_session_path = settings.CONFIDENTIAL_SESSION_DIRECTORY / vm_hash + vm_session_path.mkdir(exist_ok=True) - session_file_content = post.get("session") - if not session_file_content: - return web.json_response( - {"code": "field_missing", "description": "Session field is missing"}, - status=HTTPStatus.BAD_REQUEST, - ) + session_file_content = post.get("session") + if not session_file_content: + return web.json_response( + {"code": "field_missing", "description": "Session field is missing"}, + status=HTTPStatus.BAD_REQUEST, + ) - session_file_path = vm_session_path / "vm_session.b64" - session_file_path.write_bytes(session_file_content.file.read()) + session_file_path = vm_session_path / "vm_session.b64" + session_file_path.write_bytes(session_file_content.file.read()) - godh_file_content = post.get("godh") - if not godh_file_content: - return web.json_response( - {"code": "field_missing", "description": "godh field is missing. Please provide a GODH file"}, - status=HTTPStatus.BAD_REQUEST, - ) + godh_file_content = post.get("godh") + if not godh_file_content: + return web.json_response( + {"code": "field_missing", "description": "godh field is missing. Please provide a GODH file"}, + status=HTTPStatus.BAD_REQUEST, + ) - godh_file_path = vm_session_path / "vm_godh.b64" - godh_file_path.write_bytes(godh_file_content.file.read()) + godh_file_path = vm_session_path / "vm_godh.b64" + godh_file_path.write_bytes(godh_file_content.file.read()) - pool.systemd_manager.enable_and_start(execution.controller_service) + pool.systemd_manager.enable_and_start(execution.controller_service) - return web.Response(status=200, body=f"Started VM with ref {vm_hash}") + return web.Response(status=200, body=f"Started VM with ref {vm_hash}") @cors_allow_all @@ -233,23 +238,23 @@ async def operate_confidential_initialize(request: web.Request, authenticated_se async def operate_stop(request: web.Request, authenticated_sender: str) -> web.Response: """Stop the virtual machine, smoothly if possible.""" vm_hash = get_itemhash_or_400(request.match_info) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + logger.debug(f"Iterating through running executions... 
{pool.executions}") + execution = get_execution_or_404(vm_hash, pool=pool) - pool: VmPool = request.app["vm_pool"] - logger.debug(f"Iterating through running executions... {pool.executions}") - execution = get_execution_or_404(vm_hash, pool=pool) - - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - if execution.is_running: - logger.info(f"Stopping {execution.vm_hash}") - await pool.stop_vm(execution.vm_hash) - return web.Response(status=200, body=f"Stopped VM with ref {vm_hash}") - else: - return web.Response(status=200, body="Already stopped, nothing to do") + if execution.is_running: + logger.info(f"Stopping {execution.vm_hash}") + await pool.stop_vm(execution.vm_hash) + return web.Response(status=200, body=f"Stopped VM with ref {vm_hash}") + else: + return web.Response(status=200, body="Already stopped, nothing to do") @cors_allow_all @@ -259,24 +264,25 @@ async def operate_reboot(request: web.Request, authenticated_sender: str) -> web Reboots the virtual machine, smoothly if possible. """ vm_hash = get_itemhash_or_400(request.match_info) - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) - - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") - - if execution.is_running: - logger.info(f"Rebooting {execution.vm_hash}") - if execution.persistent: - pool.systemd_manager.restart(execution.controller_service) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) + + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") + + if execution.is_running: + logger.info(f"Rebooting {execution.vm_hash}") + if execution.persistent: + pool.systemd_manager.restart(execution.controller_service) + else: + await pool.stop_vm(vm_hash) + pool.forget_vm(vm_hash) + + await create_vm_execution_or_raise_http_error(vm_hash=vm_hash, pool=pool) + return web.Response(status=200, body=f"Rebooted VM with ref {vm_hash}") else: - await pool.stop_vm(vm_hash) - pool.forget_vm(vm_hash) - - await create_vm_execution_or_raise_http_error(vm_hash=vm_hash, pool=pool) - return web.Response(status=200, body=f"Rebooted VM with ref {vm_hash}") - else: - return web.Response(status=200, body=f"Starting VM (was not running) with ref {vm_hash}") + return web.Response(status=200, body=f"Starting VM (was not running) with ref {vm_hash}") @cors_allow_all @@ -286,23 +292,24 @@ async def operate_confidential_measurement(request: web.Request, authenticated_s Fetch the sev measurement for the VM """ vm_hash = get_itemhash_or_400(request.match_info) - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, 
body="Unauthorized sender") + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - if not execution.is_running: - raise web.HTTPForbidden(body="Operation not running") - vm_client = QemuVmClient(execution.vm) - vm_sev_info = vm_client.query_sev_info() - launch_measure = vm_client.query_launch_measure() + if not execution.is_running: + raise web.HTTPForbidden(body="Operation not running") + vm_client = QemuVmClient(execution.vm) + vm_sev_info = vm_client.query_sev_info() + launch_measure = vm_client.query_launch_measure() - return web.json_response( - data={"sev_info": vm_sev_info, "launch_measure": launch_measure}, - status=200, - dumps=dumps_for_json, - ) + return web.json_response( + data={"sev_info": vm_sev_info, "launch_measure": launch_measure}, + status=200, + dumps=dumps_for_json, + ) class InjectSecretParams(BaseModel): @@ -330,25 +337,26 @@ async def operate_confidential_inject_secret(request: web.Request, authenticated return web.json_response(data=error.json(), status=web.HTTPBadRequest.status_code) vm_hash = get_itemhash_or_400(request.match_info) - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - # if not execution.is_running: - # raise web.HTTPForbidden(body="Operation not running") - vm_client = QemuVmClient(execution.vm) - vm_client.inject_secret(params.packet_header, params.secret) - vm_client.continue_execution() + # if not execution.is_running: + # raise web.HTTPForbidden(body="Operation not running") + vm_client = QemuVmClient(execution.vm) + vm_client.inject_secret(params.packet_header, params.secret) + vm_client.continue_execution() - status = vm_client.query_status() - print(status["status"] != "running") + status = vm_client.query_status() + print(status["status"] != "running") - return web.json_response( - data={"status": status}, - status=200, - dumps=dumps_for_json, - ) + return web.json_response( + data={"status": status}, + status=200, + dumps=dumps_for_json, + ) @cors_allow_all @@ -358,25 +366,26 @@ async def operate_erase(request: web.Request, authenticated_sender: str) -> web. Stop the virtual machine first if needed. 
""" vm_hash = get_itemhash_or_400(request.match_info) - pool: VmPool = request.app["vm_pool"] - execution = get_execution_or_404(vm_hash, pool=pool) + with set_vm_for_logging(vm_hash=vm_hash): + pool: VmPool = request.app["vm_pool"] + execution = get_execution_or_404(vm_hash, pool=pool) - if not is_sender_authorized(authenticated_sender, execution.message): - return web.Response(status=403, body="Unauthorized sender") + if not is_sender_authorized(authenticated_sender, execution.message): + return web.Response(status=403, body="Unauthorized sender") - logger.info(f"Erasing {execution.vm_hash}") + logger.info(f"Erasing {execution.vm_hash}") - # Stop the VM - await pool.stop_vm(execution.vm_hash) - if execution.vm_hash in pool.executions: - logger.warning(f"VM {execution.vm_hash} was not stopped properly, forgetting it anyway") - pool.forget_vm(execution.vm_hash) - - # Delete all data - if execution.resources is not None: - for volume in execution.resources.volumes: - if not volume.read_only: - logger.info(f"Deleting volume {volume.path_on_host}") - volume.path_on_host.unlink() - - return web.Response(status=200, body=f"Erased VM with ref {vm_hash}") + # Stop the VM + await pool.stop_vm(execution.vm_hash) + if execution.vm_hash in pool.executions: + logger.warning(f"VM {execution.vm_hash} was not stopped properly, forgetting it anyway") + pool.forget_vm(execution.vm_hash) + + # Delete all data + if execution.resources is not None: + for volume in execution.resources.volumes: + if not volume.read_only: + logger.info(f"Deleting volume {volume.path_on_host}") + volume.path_on_host.unlink() + + return web.Response(status=200, body=f"Erased VM with ref {vm_hash}") diff --git a/src/aleph/vm/version.py b/src/aleph/vm/version.py index ba4f3433..73118aa7 100644 --- a/src/aleph/vm/version.py +++ b/src/aleph/vm/version.py @@ -1,17 +1,17 @@ import logging -from subprocess import CalledProcessError, check_output +from subprocess import STDOUT, CalledProcessError, check_output logger = logging.getLogger(__name__) def get_version_from_git() -> str | None: try: - return check_output(("git", "describe", "--tags")).strip().decode() + return check_output(("git", "describe", "--tags"), stderr=STDOUT).strip().decode() except FileNotFoundError: - logger.warning("git not found") + logger.warning("version: git not found") return None - except CalledProcessError: - logger.warning("git description not available") + except CalledProcessError as err: + logger.info("version: git description not available: %s", err.output.decode().strip()) return None diff --git a/tests/supervisor/views/test_operator.py b/tests/supervisor/views/test_operator.py index b8e370de..86c6c5cd 100644 --- a/tests/supervisor/views/test_operator.py +++ b/tests/supervisor/views/test_operator.py @@ -102,6 +102,82 @@ async def test_operator_confidential_initialize_already_running(aiohttp_client, } +@pytest.mark.asyncio +@pytest.mark.skip() +async def test_operator_expire(aiohttp_client, mocker): + """Test that the expires endpoint work. 
SPOILER it doesn't""" + + settings.ENABLE_QEMU_SUPPORT = True + settings.ENABLE_CONFIDENTIAL_COMPUTING = True + settings.setup() + + vm_hash = ItemHash(settings.FAKE_INSTANCE_ID) + instance_message = await get_message(ref=vm_hash) + + fake_vm_pool = mocker.Mock( + executions={ + vm_hash: mocker.Mock( + vm_hash=vm_hash, + message=instance_message.content, + is_confidential=False, + is_running=False, + ), + }, + ) + + # Disable auth + mocker.patch( + "aleph.vm.orchestrator.views.authentication.authenticate_jwk", + return_value=instance_message.sender, + ) + app = setup_webapp() + app["vm_pool"] = fake_vm_pool + client: TestClient = await aiohttp_client(app) + response = await client.post( + f"/control/machine/{vm_hash}/expire", + data={"timeout": 1}, + # json={"timeout": 1}, + ) + assert response.status == 200, await response.text() + assert fake_vm_pool["executions"][vm_hash].expire.call_count == 1 + + +@pytest.mark.asyncio +async def test_operator_stop(aiohttp_client, mocker): + """Test that the stop endpoint calls the stop_vm method on the pool""" + + settings.ENABLE_QEMU_SUPPORT = True + settings.ENABLE_CONFIDENTIAL_COMPUTING = True + settings.setup() + + vm_hash = ItemHash(settings.FAKE_INSTANCE_ID) + instance_message = await get_message(ref=vm_hash) + + fake_vm_pool = mocker.AsyncMock( + executions={ + vm_hash: mocker.AsyncMock( + vm_hash=vm_hash, + message=instance_message.content, + is_running=True, + ), + }, + ) + + # Disable auth + mocker.patch( + "aleph.vm.orchestrator.views.authentication.authenticate_jwk", + return_value=instance_message.sender, + ) + app = setup_webapp() + app["vm_pool"] = fake_vm_pool + client: TestClient = await aiohttp_client(app) + response = await client.post( + f"/control/machine/{vm_hash}/stop", + ) + assert response.status == 200, await response.text() + assert fake_vm_pool.stop_vm.call_count == 1 + + @pytest.mark.asyncio async def test_operator_confidential_initialize_not_confidential(aiohttp_client, mocker): """Test that the confidential initialize endpoint rejects if the VM is not confidential"""
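
The logging change in this diff hinges on Python's contextvars: set_vm_for_logging() stores the VM hash in a ContextVar for the duration of a request, InjectingFilter copies that hash onto each log record and drops records that have none, and a second handler picks up everything logged outside a VM context. Because ContextVar values are isolated per asyncio task, concurrent aiohttp requests each see only their own VM hash. The sketch below is a minimal, self-contained illustration of that routing, not the repository code: the plain string "fake-vm-hash" stands in for an aleph_message ItemHash and the format strings are shortened; the real implementation lives in src/aleph/vm/orchestrator/custom_logs.py.

import contextlib
import logging
from contextvars import ContextVar

# Holds the hash of the VM currently being handled, if any (per asyncio task).
ctx_vm_hash: ContextVar[str | None] = ContextVar("current_execution_hash")


@contextlib.contextmanager
def set_vm_for_logging(vm_hash: str):
    # Tag every log record emitted inside this block with vm_hash.
    token = ctx_vm_hash.set(vm_hash)
    try:
        yield
    finally:
        ctx_vm_hash.reset(token)


class InjectingFilter(logging.Filter):
    """Pass only records logged while a VM hash is set, attaching that hash to the record."""

    def filter(self, record: logging.LogRecord) -> bool:
        vm_hash = ctx_vm_hash.get(None)
        if not vm_hash:
            return False
        record.vm_hash = vm_hash
        return True


def setup_handlers() -> list[logging.Handler]:
    # Handler for records emitted inside a VM context: formatted with the VM hash.
    vm_handler = logging.StreamHandler()
    vm_handler.addFilter(InjectingFilter())
    vm_handler.setFormatter(logging.Formatter("%(levelname)s | {%(vm_hash)s} %(message)s"))
    # Handler for everything else: records emitted while no VM hash is set.
    plain_handler = logging.StreamHandler()
    plain_handler.addFilter(lambda record: ctx_vm_hash.get(None) is None)
    plain_handler.setFormatter(logging.Formatter("%(levelname)s | %(message)s"))
    return [plain_handler, vm_handler]


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, handlers=setup_handlers())
    logger = logging.getLogger(__name__)
    logger.info("outside any VM context")            # emitted by plain_handler only
    with set_vm_for_logging("fake-vm-hash"):          # hypothetical hash, for illustration
        logger.info("inside a VM request handler")   # emitted by vm_handler, tagged with the hash

Registering both handlers means every record is printed exactly once: tagged with its VM hash when one is set, untouched otherwise, which matches how the operator views wrap their bodies in set_vm_for_logging(vm_hash=vm_hash) above.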