Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use quattro #78

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"eth-pydantic-types", # Use same version as eth-ape
"packaging", # Use same version as eth-ape
"pydantic_settings", # Use same version as eth-ape
"quattro>=24.1,<25",
"taskiq[metrics]>=0.11.3,<0.12",
],
entry_points={
Expand Down
13 changes: 9 additions & 4 deletions silverback/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ async def run_worker(broker: AsyncBroker, worker_count=2, shutdown_timeout=90):
callback=_recorder_callback,
)
@click.option("-x", "--max-exceptions", type=int, default=3)
@click.option("--debug", is_flag=True, default=False)
@click.argument("path")
def run(cli_ctx, account, runner_class, recorder, max_exceptions, path):
def run(cli_ctx, account, runner_class, recorder, max_exceptions, debug, path):
if not runner_class:
# NOTE: Automatically select runner class
if cli_ctx.provider.ws_uri:
Expand All @@ -124,7 +125,7 @@ def run(cli_ctx, account, runner_class, recorder, max_exceptions, path):

app = import_from_string(path)
runner = runner_class(app, recorder=recorder, max_exceptions=max_exceptions)
asyncio.run(runner.run())
asyncio.run(runner.run(), debug=debug)


@cli.command(cls=ConnectedProviderCommand, help="Run Silverback application task workers")
Expand All @@ -138,7 +139,11 @@ def run(cli_ctx, account, runner_class, recorder, max_exceptions, path):
@click.option("-w", "--workers", type=int, default=2)
@click.option("-x", "--max-exceptions", type=int, default=3)
@click.option("-s", "--shutdown_timeout", type=int, default=90)
@click.option("--debug", is_flag=True, default=False)
@click.argument("path")
def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, path):
def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, debug, path):
app = import_from_string(path)
asyncio.run(run_worker(app.broker, worker_count=workers, shutdown_timeout=shutdown_timeout))
asyncio.run(
run_worker(app.broker, worker_count=workers, shutdown_timeout=shutdown_timeout),
debug=debug,
)
124 changes: 58 additions & 66 deletions silverback/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import atexit
from abc import ABC, abstractmethod

import quattro
from ape import chain
from ape.logging import logger
from ape.utils import ManagerAccessMixin
Expand Down Expand Up @@ -109,6 +111,10 @@ async def _event_task(self, task_data: TaskData):
handle an event handler task for the given contract event
"""

def _shutdown(self):
asyncio.run(self.app.broker.shutdown(), debug=True)
logger.info("Application shutdown completed")

async def run(self):
"""
Run the task broker client for the assembled ``SilverbackApp`` application.
Expand All @@ -124,6 +130,8 @@ async def run(self):
"""
# Initialize broker (run worker startup events)
await self.app.broker.startup()
# NOTE: Always ensure we shutdown the broker no matter what
atexit.register(self._shutdown)

# Obtain system configuration for worker
result = await run_taskiq_task_wait_result(
Expand All @@ -133,18 +141,18 @@ async def run(self):
raise StartupFailure("Unable to determine system configuration of worker")

# NOTE: Increase the specifier set here if there is a breaking change to this
if Version(result.return_value.sdk_version) not in SpecifierSet(">=0.5.0"):
# TODO: set to next breaking change release before release
if (sdk_version := Version(result.return_value.sdk_version)) not in SpecifierSet(">=0.5.0"):
raise StartupFailure("Worker SDK version too old, please rebuild")

if not (
system_tasks := set(TaskType(task_name) for task_name in result.return_value.task_types)
):
# NOTE: Guaranteed to be at least one because of `TaskType.SYSTEM_CONFIG`
raise StartupFailure("No system tasks detected, startup failure")
# NOTE: Guaranteed to be at least one because of `TaskType.SYSTEM_CONFIG`

system_tasks_str = "\n- ".join(system_tasks)
logger.info(
f"Worker using Silverback SDK v{result.return_value.sdk_version}"
f"Worker using Silverback SDK v{sdk_version}"
f", available task types:\n- {system_tasks_str}"
)

Expand All @@ -163,20 +171,18 @@ async def run(self):
self.state = AppState(last_block_seen=-1, last_block_processed=-1)

# Execute Silverback startup task before we init the rest
startup_taskdata_result = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.STARTUP
)

if startup_taskdata_result.is_err:
raise StartupFailure(startup_taskdata_result.error)

else:
startup_task_handlers = map(
self._create_task_kicker, startup_taskdata_result.return_value
if (
startup_taskdata_result := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.STARTUP
)
).is_err:
raise StartupFailure(startup_taskdata_result.error)

elif startup_task_handlers := tuple(
map(self._create_task_kicker, startup_taskdata_result.return_value)
):
startup_task_results = await run_taskiq_task_group_wait_results(
(task_handler for task_handler in startup_task_handlers), self.state
startup_task_handlers, self.state
)

if any(result.is_err for result in startup_task_results):
Expand All @@ -187,21 +193,26 @@ async def run(self):

elif self.recorder:
converted_results = map(TaskResult.from_taskiq, startup_task_results)
await asyncio.gather(*(self.recorder.add_result(r) for r in converted_results))
await quattro.gather(*(self.recorder.add_result(r) for r in converted_results))

# NOTE: No need to handle results otherwise
# else: No need to handle results otherwise

else:
logger.info("No startup tasks detected")

# Create our long-running event listeners
new_block_taskdata_results = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK
)
if new_block_taskdata_results.is_err:
if (
new_block_taskdata_results := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK
)
).is_err:
raise StartupFailure(new_block_taskdata_results.error)

event_log_taskdata_results = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.EVENT_LOG
)
if event_log_taskdata_results.is_err:
if (
event_log_taskdata_results := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.EVENT_LOG
)
).is_err:
raise StartupFailure(event_log_taskdata_results.error)

if (
Expand All @@ -212,50 +223,28 @@ async def run(self):
raise NoTasksAvailableError()

# NOTE: Any propagated failure in here should be handled such that shutdown tasks also run
# TODO: `asyncio.TaskGroup` added in Python 3.11
listener_tasks = (
*(
asyncio.create_task(self._block_task(task_def))
for task_def in new_block_taskdata_results.return_value
),
*(
asyncio.create_task(self._event_task(task_def))
for task_def in event_log_taskdata_results.return_value
),
)

# NOTE: Safe to do this because no tasks were actually scheduled to run
if len(listener_tasks) == 0:
raise NoTasksAvailableError()

# Run until one task bubbles up an exception that should stop execution
tasks_with_errors, tasks_running = await asyncio.wait(
listener_tasks, return_when=asyncio.FIRST_EXCEPTION
exceptions_or_none = await quattro.gather(
*(self._block_task(task_def) for task_def in new_block_taskdata_results.return_value),
*(self._event_task(task_def) for task_def in event_log_taskdata_results.return_value),
return_exceptions=True,
)
if runtime_errors := "\n".join(str(task.exception()) for task in tasks_with_errors):
# NOTE: In case we are somehow not displaying the error correctly with task status
logger.debug(f"Runtime error(s) detected, shutting down:\n{runtime_errors}")

# Cancel any still running
(task.cancel() for task in tasks_running)
# NOTE: All listener tasks are shut down now
# NOTE: Result is either None or Exception
if err_msg := "\n\n".join(str(e) for e in exceptions_or_none if e):
logger.error(f"Runtime error(s) detected, shutting down:\n{err_msg}")

# Execute Silverback shutdown task(s) before shutting down the broker and app
shutdown_taskdata_result = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.SHUTDOWN
)

if shutdown_taskdata_result.is_err:
raise StartupFailure(shutdown_taskdata_result.error)

else:
shutdown_task_handlers = map(
self._create_task_kicker, shutdown_taskdata_result.return_value
if (
shutdown_taskdata_result := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.SHUTDOWN
)
).is_err:
raise RuntimeError(shutdown_taskdata_result.error)

shutdown_task_results = await run_taskiq_task_group_wait_results(
(task_handler for task_handler in shutdown_task_handlers)
)
elif shutdown_task_handlers := tuple(
map(self._create_task_kicker, shutdown_taskdata_result.return_value)
):
shutdown_task_results = await run_taskiq_task_group_wait_results(shutdown_task_handlers)

if any(result.is_err for result in shutdown_task_results):
errors_str = "\n".join(
Expand All @@ -265,11 +254,14 @@ async def run(self):

elif self.recorder:
converted_results = map(TaskResult.from_taskiq, shutdown_task_results)
await asyncio.gather(*(self.recorder.add_result(r) for r in converted_results))
await quattro.gather(*(self.recorder.add_result(r) for r in converted_results))

# else: No need to handle results otherwise

# NOTE: No need to handle results otherwise
else:
logger.info("No shutdown tasks detected")

await self.app.broker.shutdown()
# NOTE: atexit handles self.app.broker.shutdown()


class WebsocketRunner(BaseRunner, ManagerAccessMixin):
Expand Down
54 changes: 28 additions & 26 deletions silverback/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import json
from collections import defaultdict
from enum import Enum
from typing import AsyncGenerator

import quattro
from ape.logging import logger
from websockets import ConnectionClosedError
from websockets import client as ws_client
Expand All @@ -28,25 +30,25 @@ def __init__(self, ws_provider_uri: str):
# Stateful
self._connection: ws_client.WebSocketClientProtocol | None = None
self._last_request: int = 0
self._subscriptions: dict[str, asyncio.Queue] = {}
self._subscriptions: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)
self._rpc_msg_buffer: list[dict] = []
self._ws_lock = asyncio.Lock()

def __repr__(self) -> str:
return f"<{self.__class__.__name__} uri={self._ws_provider_uri}>"

async def __aenter__(self) -> "Web3SubscriptionsManager":
self.connection = await ws_client.connect(self._ws_provider_uri)
self._connection = await ws_client.connect(self._ws_provider_uri)
return self

def __aiter__(self) -> "Web3SubscriptionsManager":
return self

async def __anext__(self) -> str:
if not self.connection:
if not self._connection:
raise StopAsyncIteration

message = await self.connection.recv()
message = await self._connection.recv()
# TODO: Handle retries when connection breaks

response = json.loads(message)
Expand All @@ -56,9 +58,6 @@ async def __anext__(self) -> str:
logger.debug(f"Corrupted subscription data: {response}")
return response

if sub_id not in self._subscriptions:
self._subscriptions[sub_id] = asyncio.Queue()

await self._subscriptions[sub_id].put(sub_params.get("result", {}))

else:
Expand Down Expand Up @@ -94,7 +93,7 @@ async def _get_response(self, request_id: int) -> dict:
raise RuntimeError("Timeout waiting for response.")

async def subscribe(self, type: SubscriptionType, **filter_params) -> str:
if not self.connection:
if not self._connection:
raise ValueError("Connection required.")

if type is SubscriptionType.BLOCKS and filter_params:
Expand All @@ -104,7 +103,7 @@ async def subscribe(self, type: SubscriptionType, **filter_params) -> str:
"eth_subscribe",
[type.value, filter_params] if type is SubscriptionType.EVENTS else [type.value],
)
await self.connection.send(json.dumps(request))
await self._connection.send(json.dumps(request))
response = await self._get_response(request.get("id") or self._last_request)

sub_id = response.get("result")
Expand All @@ -116,24 +115,27 @@ async def subscribe(self, type: SubscriptionType, **filter_params) -> str:

async def get_subscription_data(self, sub_id: str) -> AsyncGenerator[dict, None]:
while True:
if not (queue := self._subscriptions.get(sub_id)) or queue.empty():
if self._subscriptions[sub_id].empty():
async with self._ws_lock:
# Keep pulling until a message comes to process
# NOTE: Python <3.10 does not support `anext` function
await self.__anext__()
else:
yield await queue.get()
yield await self._subscriptions[sub_id].get()

async def unsubscribe(self, sub_id: str) -> bool:
if sub_id not in self._subscriptions:
raise ValueError(f"Unknown sub_id '{sub_id}'")

if not self.connection:
if not self._connection:
# Nothing to unsubscribe.
return True

request = self._create_request("eth_unsubscribe", [sub_id])
await self.connection.send(json.dumps(request))
try:
await self._connection.send(json.dumps(request))
except ConnectionClosedError:
return False

response = await self._get_response(request.get("id") or self._last_request)
if success := response.get("result", False):
Expand All @@ -142,16 +144,16 @@ async def unsubscribe(self, sub_id: str) -> bool:
return success

async def __aexit__(self, exc_type, exc, tb):
try:
# Try to gracefully unsubscribe to all events
await asyncio.gather(*(self.unsubscribe(sub_id) for sub_id in self._subscriptions))

except ConnectionClosedError:
pass # Websocket already closed (ctrl+C and patiently waiting)

finally:
# Disconnect and release websocket
try:
await self.connection.close()
except RuntimeError:
pass # No running event loop to disconnect from (multiple ctrl+C presses)
if not all(
is_successful is True
for is_successful in await quattro.gather(
# Try to gracefully unsubscribe to all events
*(self.unsubscribe(sub_id) for sub_id in self._subscriptions),
# NOTE: Do not catch error
return_exceptions=True,
)
):
logger.debug("Failed to unsubscribe from all tasks")

# Disconnect and release websocket
await self._connection.close()
Loading