Skip to content

Commit

Permalink
add ability to subscribe to buildflow events
Browse files Browse the repository at this point in the history
  • Loading branch information
boetro committed Nov 27, 2023
1 parent 2e92311 commit d951778
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 54 deletions.
15 changes: 14 additions & 1 deletion buildflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime
from pprint import pprint
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import pathspec
import ray
Expand Down Expand Up @@ -55,6 +55,7 @@ def run_flow(
serve_host: str,
serve_port: int,
flow_state: Optional[FlowState] = None,
event_subscriber: Optional[Callable] = None,
):
if isinstance(flow, buildflow.Flow):
if reload:
Expand All @@ -67,6 +68,7 @@ def run_flow(
runtime_server_port=runtime_server_port,
serve_host=serve_host,
serve_port=serve_port,
event_subscriber=event_subscriber,
)
asyncio.run(watcher.run())

Expand All @@ -79,6 +81,7 @@ def run_flow(
flow_state=flow_state,
serve_host=serve_host,
serve_port=serve_port,
event_subscriber=event_subscriber,
)
else:
typer.echo(f"{app} is not a buildflow flow.")
Expand Down Expand Up @@ -111,7 +114,15 @@ def run(
"",
help="The location the build will be decompressed to. Only relevent if setting --from-build, defaults to a temporary directory", # noqa
),
event_subscriber: str = typer.Option(
None,
help="A python function that will be called when a status change occurs in the runtime.", # noqa
),
):
event_subscriber_import = None
if event_subscriber:
sys.path.insert(0, "")
event_subscriber_import = utils.import_from_string(event_subscriber)
if not from_build:
buildflow_config = BuildFlowConfig.load()
sys.path.insert(0, "")
Expand All @@ -126,6 +137,7 @@ def run(
runtime_server_port,
serve_host,
serve_port,
event_subscriber=event_subscriber_import,
)
else:
if reload:
Expand Down Expand Up @@ -157,6 +169,7 @@ def run(
runtime_server_host,
runtime_server_port,
flow_state=flow_state,
envent_subscriber=event_subscriber_import,
)


Expand Down
3 changes: 3 additions & 0 deletions buildflow/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import builtins
import importlib
import sys
from typing import Any
Expand All @@ -22,6 +23,8 @@ def import_from_string(import_str: str) -> Any:
import_str = f"{filename[:-3]}:{varname}"
else:
raise ValueError(f"Could not find Flow object in {filename}")
if hasattr(builtins, import_str):
return getattr(builtins, import_str)
module_str, _, attrs_str = import_str.partition(":")
if not module_str or not attrs_str:
message = (
Expand Down
4 changes: 4 additions & 0 deletions buildflow/cli/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import signal
import sys
from multiprocessing import get_context
from typing import Callable, Optional

from watchfiles import awatch

Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
runtime_server_port: int,
serve_host: str,
serve_port: int,
event_subscriber: Optional[Callable],
) -> None:
self.app = app
self.start_runtime_server = start_runtime_server
Expand All @@ -53,6 +55,7 @@ def __init__(
self.runtime_server_port = runtime_server_port
self.serve_host = serve_host
self.serve_port = serve_port
self.event_subscriber = event_subscriber

async def run(self):
self.process = self.start_process()
Expand All @@ -79,6 +82,7 @@ def start_process(self):
"runtime_server_port": self.runtime_server_port,
"serve_host": self.serve_host,
"serve_port": self.serve_port,
"event_subscriber": self.event_subscriber,
},
)
process.start()
Expand Down
15 changes: 13 additions & 2 deletions buildflow/core/app/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,7 @@ def run(
runtime_server_port: int = 9653,
# Options for testing
block: bool = True,
event_subscriber: Optional[Callable] = None,
):
self._add_service_groups()
if not self._processor_groups:
Expand All @@ -773,7 +774,10 @@ def run(
# Setup services
# Start the Flow Runtime
runtime_coroutine = self._run(
debug_run=debug_run, serve_host=serve_host, serve_port=serve_port
debug_run=debug_run,
serve_host=serve_host,
serve_port=serve_port,
event_subscriber=event_subscriber,
)

if debug_run:
Expand Down Expand Up @@ -812,7 +816,13 @@ def run(
else:
return runtime_coroutine

async def _run(self, serve_host: str, serve_port: int, debug_run: bool = False):
async def _run(
self,
serve_host: str,
serve_port: int,
debug_run: bool = False,
event_subscriber: Optional[Callable] = None,
):
# Add a signal handler to drain the runtime when the process is killed
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
Expand All @@ -826,6 +836,7 @@ async def _run(self, serve_host: str, serve_port: int, debug_run: bool = False):
processor_groups=self._processor_groups,
serve_host=serve_host,
serve_port=serve_port,
event_subscriber=event_subscriber,
)
await self._get_runtime_actor().run_until_complete.remote()

Expand Down
30 changes: 30 additions & 0 deletions buildflow/core/app/runtime/_runtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import enum
from typing import Dict, Optional

RunID = str

Expand All @@ -8,10 +10,38 @@ class RuntimeStatus(enum.Enum):
RUNNING = enum.auto()
DRAINING = enum.auto()
DRAINED = enum.auto()
DIED = enum.auto()
STOPPED = enum.auto()
# Used for local reloading
RELOADING = enum.auto()


@dataclasses.dataclass
class RuntimeStatusReport:
status: RuntimeStatus
processor_group_statuses: Dict[str, RuntimeStatus]

def to_dict(self):
return {
"status": self.status.name,
"processor_group_statuses": {
k: v.name for k, v in self.processor_group_statuses.items()
},
}


@dataclasses.dataclass
class RuntimeEvent:
run_id: RunID
status_change: Optional[RuntimeStatusReport]

def to_dict(self):
return {
"run_id": self.run_id,
"status_change": self.status_change.to_dict(),
}


class Snapshot:
status: RuntimeStatus
timestamp_millis: int
Expand Down
7 changes: 1 addition & 6 deletions buildflow/core/app/runtime/actors/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,6 @@ async def add_replicas(self, num_replicas: int):
"this can happen if a drain occurs at the same time as a scale up."
)
return
if self._status != RuntimeStatus.RUNNING:
raise RuntimeError(
"Can only add replicas to a processor pool that is running "
f"was in state: {self._status}."
)
for _ in range(num_replicas):
replica = await self.create_replica()

Expand Down Expand Up @@ -194,7 +189,7 @@ async def drain(self):
logging.info(f"Drain ProcessorPool({self.processor_group.group_id}) complete.")
return True

async def status(self):
async def status(self) -> RuntimeStatus:
return self._status

# NOTE: Subclasses should override this method if they need to provide additional
Expand Down
76 changes: 72 additions & 4 deletions buildflow/core/app/runtime/actors/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

import ray
from ray.actor import ActorHandle
from ray.exceptions import OutOfMemoryError, RayActorError

from buildflow.core import utils
from buildflow.core.app.runtime._runtime import RunID, Runtime, RuntimeStatus, Snapshot
from buildflow.core.app.runtime._runtime import (
RunID,
Runtime,
RuntimeEvent,
RuntimeStatus,
RuntimeStatusReport,
Snapshot,
)
from buildflow.core.app.runtime.actors.collector_pattern.collector_pool import (
CollectorProcessorPoolActor,
)
Expand Down Expand Up @@ -71,6 +78,7 @@ def __init__(
self._processor_group_pool_refs: List[ProcessorGroupPoolReference] = []
self._runtime_loop_future = None
self.flow_dependencies = flow_dependencies
self._event_subscriber = None

def _set_status(self, status: RuntimeStatus):
self._status = status
Expand Down Expand Up @@ -130,8 +138,10 @@ async def run(
processor_groups: Iterable[ProcessorGroup],
serve_host: str,
serve_port: int,
event_subscriber: Optional[Callable],
):
logging.info("Starting Runtime...")
self._event_subscriber = event_subscriber
self._set_status(RuntimeStatus.RUNNING)
self._processor_group_pool_refs = []
await self.initialize_global_dependencies(processor_groups)
Expand All @@ -154,6 +164,15 @@ async def drain(self, as_reload: bool = False) -> bool:
for processor_pool in self._processor_group_pool_refs
]
# Kill the runtime actor to stop the even loop.
if self._event_subscriber is not None:
event = RuntimeEvent(
self.run_id,
RuntimeStatusReport(
status=RuntimeStatus.STOPPED,
processor_group_statuses={},
),
)
self._event_subscriber(event)
ray.actor.exit_actor()
else:
if as_reload:
Expand Down Expand Up @@ -194,7 +213,11 @@ async def run_until_complete(self):
if self._runtime_loop_future is not None:
await self._runtime_loop_future

async def _runtime_checkin_loop(self, serve_host: str, serve_port: int):
async def _runtime_checkin_loop(
self,
serve_host: str,
serve_port: int,
):
logging.info("Runtime checkin loop started...")
last_autoscale_event = time.monotonic()
# We keep running the loop while the job is running or draining to ensure
Expand All @@ -203,16 +226,24 @@ async def _runtime_checkin_loop(self, serve_host: str, serve_port: int):
# processor types need scaling (e.g. only consumer).
# - one for checking the status (i.e. is it still running)
# - one for autoscaling
previous_status_report = None
while (
self._status == RuntimeStatus.RUNNING
or self._status == RuntimeStatus.DRAINING
):
scaling_coros = []
processor_group_statuses = {}
for processor_pool in self._processor_group_pool_refs:
try:
# Check to see if our processpool actor needs to be restarted.
await processor_pool.actor_handle.status.remote()
status = await processor_pool.actor_handle.status.remote()
processor_group_statuses[
processor_pool.processor_group.group_id
] = status
except (RayActorError, OutOfMemoryError):
processor_group_statuses[
processor_pool.processor_group.group_id
] = RuntimeStatus.DIED
logging.exception("process actor unexpectedly died. will restart.")
if self._status == RuntimeStatus.RUNNING:
# Only restart if we are running, otherwise we are draining
Expand All @@ -222,6 +253,20 @@ async def _runtime_checkin_loop(self, serve_host: str, serve_port: int):
serve_port=serve_port,
)
processor_pool.actor_handle = new_processor_ref.actor_handle
if self._event_subscriber is not None:
status_report = RuntimeStatusReport(
status=self._status,
processor_group_statuses=processor_group_statuses,
)

try:
if previous_status_report != status_report:
event = RuntimeEvent(self.run_id, status_report)
self._event_subscriber(event)
except Exception:
logging.exception("event subscriber failed")
previous_status_report = status_report

processor_options = self.options.processor_options[
processor_pool.processor_group.group_id
]
Expand All @@ -245,3 +290,26 @@ async def _runtime_checkin_loop(self, serve_host: str, serve_port: int):
logging.debug("autoscale check ended at: %s", datetime.utcnow())

await asyncio.sleep(self.options.checkin_frequency_loop_secs)
processor_group_statuses = {}
for processor_pool in self._processor_group_pool_refs:
try:
# Check to see if our processpool actor needs to be restarted.
status = await processor_pool.actor_handle.status.remote()
processor_group_statuses[
processor_pool.processor_group.group_id
] = status
except (RayActorError, OutOfMemoryError):
processor_group_statuses[
processor_pool.processor_group.group_id
] = RuntimeStatus.DIED
if self._event_subscriber is not None:
status_report = RuntimeStatusReport(
status=self._status,
processor_group_statuses=processor_group_statuses,
)

try:
event = RuntimeEvent(self.run_id, status_report)
self._event_subscriber(event)
except Exception:
logging.exception("event subscriber failed")
Loading

0 comments on commit d951778

Please sign in to comment.