Skip to content

Commit

Permalink
Fix: Solved issue getting already running executions with GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
nesitor committed Dec 6, 2024
1 parent 60b2491 commit ba1fc9d
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/aleph/vm/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import logging
import uuid
from asyncio import Task
Expand All @@ -14,6 +15,7 @@
ProgramContent,
)
from aleph_message.models.execution.environment import GpuProperties, HypervisorType
from pydantic.json import pydantic_encoder

from aleph.vm.conf import settings
from aleph.vm.controllers.firecracker.executable import AlephFirecrackerExecutable
Expand Down Expand Up @@ -460,6 +462,7 @@ async def save(self):
message=self.message.json(),
original_message=self.original.json(),
persistent=self.persistent,
gpus=json.dumps(self.gpus, default=pydantic_encoder),
)
)

Expand Down
2 changes: 2 additions & 0 deletions src/aleph/vm/orchestrator/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class ExecutionRecord(Base):
original_message = Column(JSON, nullable=True)
persistent = Column(Boolean, nullable=True)

gpus = Column(JSON, nullable=True)

def __repr__(self):
return f"<ExecutionRecord(uuid={self.uuid}, vm_hash={self.vm_hash}, vm_id={self.vm_id})>"

Expand Down
10 changes: 7 additions & 3 deletions src/aleph/vm/orchestrator/payment.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,13 @@ async def get_stream(sender: str, receiver: str, chain: str) -> Decimal:
Get the stream of the user from the Superfluid API.
See https://community.aleph.im/t/pay-as-you-go-using-superfluid/98/11
"""
chain_info: ChainInfo = get_chain(chain=chain)
if not chain_info.active:
msg = f"Chain : {chain} is not active for superfluid"
try:
chain_info: ChainInfo = get_chain(chain=chain)

Check warning on line 104 in src/aleph/vm/orchestrator/payment.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/orchestrator/payment.py#L103-L104

Added lines #L103 - L104 were not covered by tests
if not chain_info.active:
msg = f"Chain : {chain} is not active for superfluid"
raise InvalidChainError(msg)
except ValueError:
msg = f"Chain : {chain} is invalid"

Check warning on line 109 in src/aleph/vm/orchestrator/payment.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/orchestrator/payment.py#L106-L109

Added lines #L106 - L109 were not covered by tests
raise InvalidChainError(msg)

superfluid_instance = CFA_V1(chain_info.rpc, chain_info.chain_id)
Expand Down
13 changes: 9 additions & 4 deletions src/aleph/vm/orchestrator/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import time
from collections.abc import AsyncIterable
from decimal import Decimal
from typing import TypeVar

import aiohttp
Expand Down Expand Up @@ -175,10 +176,14 @@ async def monitor_payments(app: web.Application):
# Check if the balance held in the wallet is sufficient stream tier resources
for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.superfluid).items():
for chain, executions in chains.items():
stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain)
logger.debug(
f"Get stream flow from Sender {sender} to Receiver {settings.PAYMENT_RECEIVER_ADDRESS} of {stream}"
)
try:
stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain)
logger.debug(

Check warning on line 181 in src/aleph/vm/orchestrator/tasks.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/orchestrator/tasks.py#L179-L181

Added lines #L179 - L181 were not covered by tests
f"Get stream flow from Sender {sender} to Receiver {settings.PAYMENT_RECEIVER_ADDRESS} of {stream}"
)
except ValueError as error:
logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}")
stream = Decimal(0)

Check warning on line 186 in src/aleph/vm/orchestrator/tasks.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/orchestrator/tasks.py#L184-L186

Added lines #L184 - L186 were not covered by tests

required_stream = await compute_required_flow(executions)
logger.debug(f"Required stream for Sender {sender} executions: {required_stream}")
Expand Down
4 changes: 2 additions & 2 deletions src/aleph/vm/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Payment,
PaymentType,
)
from pydantic import parse_raw_as

from aleph.vm.conf import settings
from aleph.vm.controllers.firecracker.snapshot_manager import SnapshotManager
Expand Down Expand Up @@ -241,8 +242,7 @@ async def load_persistent_executions(self):
if execution.is_running:
# TODO: Improve the way that we re-create running execution
# Load existing GPUs assigned to VMs
for saved_gpu in saved_execution.gpus:
execution.gpus.append(HostGPU(pci_host=saved_gpu.pci_host))
execution.gpus = parse_raw_as(List[HostGPU], saved_execution.gpus)

Check warning on line 245 in src/aleph/vm/pool.py

View check run for this annotation

Codecov / codecov/patch

src/aleph/vm/pool.py#L245

Added line #L245 was not covered by tests
# Load and instantiate the rest of resources and already assigned GPUs
await execution.prepare()
if self.network:
Expand Down
15 changes: 10 additions & 5 deletions src/aleph/vm/resources.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import subprocess
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

from aleph_message.models import HashableModel
from pydantic import Extra, Field
from pydantic import BaseModel, Extra, Field


@dataclass
class HostGPU:
pci_host: str
class HostGPU(BaseModel):
"""Host GPU properties detail."""

pci_host: str = Field(description="GPU PCI host address")

class Config:
extra = Extra.forbid


class GpuDeviceClass(str, Enum):
"""GPU device class. Look at https://admin.pci-ids.ucw.cz/read/PD/03"""

VGA_COMPATIBLE_CONTROLLER = "0300"
_3D_CONTROLLER = "0302"

Expand Down

0 comments on commit ba1fc9d

Please sign in to comment.