Skip to content

Commit

Permalink
Cleanup: Minor code cleanup and refactoring (#546)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoh authored Feb 20, 2024
1 parent a470f4e commit e37bb6f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/aleph/vm/orchestrator/views/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ async def operate_erase(request: web.Request, authenticated_sender: str) -> web.

# Stop the VM
await pool.stop_vm(execution.vm_hash)
pool.forget_vm(execution.vm_hash)
if execution.vm_hash in pool.executions:
logger.warning(f"VM {execution.vm_hash} was not stopped properly, forgetting it anyway")
pool.forget_vm(execution.vm_hash)

# Delete all data
if execution.resources is not None:
Expand Down
38 changes: 26 additions & 12 deletions src/aleph/vm/pool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import json
import logging
Expand Down Expand Up @@ -128,11 +130,7 @@ async def create_a_vm(
self.forget_vm(vm_hash)
raise

async def forget_on_stop(stop_event: asyncio.Event):
await stop_event.wait()
self.forget_vm(vm_hash)

asyncio.create_task(forget_on_stop(stop_event=execution.stop_event))
self._schedule_forget_on_stop(execution)

return execution

Expand Down Expand Up @@ -184,7 +182,7 @@ async def stop_vm(self, vm_hash: ItemHash) -> Optional[VmExecution]:
else:
return None

async def stop_persistent_execution(self, execution):
async def stop_persistent_execution(self, execution: VmExecution):
"""Stop persistent VMs in the pool."""
assert execution.persistent, "Execution isn't persistent"
self.systemd_manager.stop_and_disable(execution.controller_service)
Expand All @@ -202,31 +200,45 @@ def forget_vm(self, vm_hash: ItemHash) -> None:
except KeyError:
pass

def _schedule_forget_on_stop(self, execution: VmExecution):
"""Create a task that will remove the VM from the pool after it stops."""

async def forget_on_stop(stop_event: asyncio.Event):
await stop_event.wait()
self.forget_vm(execution.vm_hash)

_ = asyncio.create_task(forget_on_stop(stop_event=execution.stop_event))

async def _load_persistent_executions(self):
"""Load persistent executions from the database."""
saved_executions = await get_execution_records()
for saved_execution in saved_executions:
# Prevent to load the same execution twice
if self.executions.get(saved_execution.vm_hash):
vm_hash = ItemHash(saved_execution.vm_hash)

if vm_hash in self.executions:
# The execution is already loaded, skip it
continue

vm_id = saved_execution.vm_id

message_dict = json.loads(saved_execution.message)
original_dict = json.loads(saved_execution.original_message)

execution = VmExecution(
vm_hash=saved_execution.vm_hash,
vm_hash=vm_hash,
message=get_message_executable_content(message_dict),
original=get_message_executable_content(message_dict),
original=get_message_executable_content(original_dict),
snapshot_manager=self.snapshot_manager,
systemd_manager=self.systemd_manager,
persistent=saved_execution.persistent,
)

if execution.is_running:
# TODO: Improve the way that we re-create running execution
await execution.prepare()
if self.network:
vm_type = VmType.from_message_content(execution.message)
tap_interface = await self.network.prepare_tap(vm_id, execution.vm_hash, vm_type)
tap_interface = await self.network.prepare_tap(vm_id, vm_hash, vm_type)
else:
tap_interface = None

Expand All @@ -235,7 +247,9 @@ async def _load_persistent_executions(self):
execution.ready_event.set()
execution.times.started_at = datetime.now(tz=timezone.utc)

self.executions[execution.vm_hash] = execution
self._schedule_forget_on_stop(execution)

self.executions[vm_hash] = execution
else:
execution.uuid = saved_execution.uuid
await execution.record_usage()
Expand Down

0 comments on commit e37bb6f

Please sign in to comment.