diff --git a/src/aleph/vm/orchestrator/views/operator.py b/src/aleph/vm/orchestrator/views/operator.py
index 876415d78..7e9482883 100644
--- a/src/aleph/vm/orchestrator/views/operator.py
+++ b/src/aleph/vm/orchestrator/views/operator.py
@@ -204,7 +204,9 @@ async def operate_erase(request: web.Request, authenticated_sender: str) -> web.
 
     # Stop the VM
     await pool.stop_vm(execution.vm_hash)
-    pool.forget_vm(execution.vm_hash)
+    if execution.vm_hash in pool.executions:
+        logger.warning("VM %s was not stopped properly, forgetting it anyway", execution.vm_hash)
+        pool.forget_vm(execution.vm_hash)
 
     # Delete all data
     if execution.resources is not None:
diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py
index eb9d7ec48..0c1673cee 100644
--- a/src/aleph/vm/pool.py
+++ b/src/aleph/vm/pool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import logging
@@ -128,11 +130,7 @@ async def create_a_vm(
             self.forget_vm(vm_hash)
             raise
 
-        async def forget_on_stop(stop_event: asyncio.Event):
-            await stop_event.wait()
-            self.forget_vm(vm_hash)
-
-        asyncio.create_task(forget_on_stop(stop_event=execution.stop_event))
+        self._schedule_forget_on_stop(execution)
 
         return execution
 
@@ -184,7 +182,7 @@ async def stop_vm(self, vm_hash: ItemHash) -> Optional[VmExecution]:
         else:
             return None
 
-    async def stop_persistent_execution(self, execution):
+    async def stop_persistent_execution(self, execution: VmExecution):
         """Stop persistent VMs in the pool."""
         assert execution.persistent, "Execution isn't persistent"
         self.systemd_manager.stop_and_disable(execution.controller_service)
@@ -202,31 +200,52 @@ def forget_vm(self, vm_hash: ItemHash) -> None:
         except KeyError:
             pass
 
+    def _schedule_forget_on_stop(self, execution: VmExecution):
+        """Create a task that will remove the VM from the pool after it stops."""
+
+        async def forget_on_stop(stop_event: asyncio.Event):
+            await stop_event.wait()
+            self.forget_vm(execution.vm_hash)
+
+        # The event loop only holds weak references to tasks: keep a strong
+        # reference until the task completes, otherwise it may be
+        # garbage-collected before the VM stops and forget_vm never runs.
+        if not hasattr(self, "_forget_on_stop_tasks"):
+            self._forget_on_stop_tasks = set()
+        task = asyncio.create_task(forget_on_stop(stop_event=execution.stop_event))
+        self._forget_on_stop_tasks.add(task)
+        task.add_done_callback(self._forget_on_stop_tasks.discard)
+
     async def _load_persistent_executions(self):
         """Load persistent executions from the database."""
         saved_executions = await get_execution_records()
         for saved_execution in saved_executions:
-            # Prevent to load the same execution twice
-            if self.executions.get(saved_execution.vm_hash):
+            vm_hash = ItemHash(saved_execution.vm_hash)
+
+            if vm_hash in self.executions:
+                # The execution is already loaded, skip it
                 continue
 
             vm_id = saved_execution.vm_id
+
             message_dict = json.loads(saved_execution.message)
             original_dict = json.loads(saved_execution.original_message)
+
             execution = VmExecution(
-                vm_hash=saved_execution.vm_hash,
+                vm_hash=vm_hash,
                 message=get_message_executable_content(message_dict),
-                original=get_message_executable_content(message_dict),
+                original=get_message_executable_content(original_dict),
                 snapshot_manager=self.snapshot_manager,
                 systemd_manager=self.systemd_manager,
                 persistent=saved_execution.persistent,
             )
+
             if execution.is_running:
                 # TODO: Improve the way that we re-create running execution
                 await execution.prepare()
                 if self.network:
                     vm_type = VmType.from_message_content(execution.message)
-                    tap_interface = await self.network.prepare_tap(vm_id, execution.vm_hash, vm_type)
+                    tap_interface = await self.network.prepare_tap(vm_id, vm_hash, vm_type)
                 else:
                     tap_interface = None
 
@@ -235,7 +254,9 @@ async def _load_persistent_executions(self):
                 execution.ready_event.set()
                 execution.times.started_at = datetime.now(tz=timezone.utc)
 
-                self.executions[execution.vm_hash] = execution
+                self._schedule_forget_on_stop(execution)
+
+                self.executions[vm_hash] = execution
             else:
                 execution.uuid = saved_execution.uuid
                 await execution.record_usage()