diff --git a/bittensor/axon.py b/bittensor/axon.py index 8cefadfe61..0a62d6c84b 100644 --- a/bittensor/axon.py +++ b/bittensor/axon.py @@ -26,6 +26,7 @@ import inspect import json import os +import socket import threading import time import traceback @@ -62,6 +63,12 @@ from bittensor.utils import networking +""" +The quantum of time to sleep in waiting loops, in seconds. +""" +TIME_SLEEP_INTERVAL: float = 1e-3 + + class FastAPIThreadedServer(uvicorn.Server): """ The ``FastAPIThreadedServer`` class is a specialized server implementation for the Axon server in the Bittensor network. @@ -100,26 +107,80 @@ class FastAPIThreadedServer(uvicorn.Server): should_exit: bool = False is_running: bool = False + """ + Provide a channel to signal exceptions from the thread to our caller. + """ + _exception: Optional[Exception] = None + _lock: threading.Lock = threading.Lock() + _thread: Optional[threading.Thread] = None + _started: bool = False + + def set_exception(self, exception: Exception) -> None: + """ + Set self._exception in a thread safe manner, so the worker thread can communicate exceptions to the main thread. + """ + with self._lock: + self._exception = exception + + def get_exception(self) -> Optional[Exception]: + with self._lock: + return self._exception + + def set_thread(self, thread: threading.Thread): + """ + Set self._thread in a thread safe manner, so the main thread can get the worker thread object. + """ + with self._lock: + self._thread = thread + + def get_thread(self) -> Optional[threading.Thread]: + with self._lock: + return self._thread + + def set_started(self, started: bool) -> None: + """ + Set self._started in a thread safe manner, so the main thread can get the worker thread status. + """ + with self._lock: + self._started = started + + def get_started(self) -> bool: + with self._lock: + return self._started + def install_signal_handlers(self): """ Overrides the default signal handlers provided by ``uvicorn.Server``. This method is essential to ensure that the signal handling in the threaded server does not interfere with the main application's flow, especially in a complex asynchronous environment like the Axon server. """ pass + async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None: + """ + Adds a thread-safe call to set a 'started' flag on the object. + """ + await super().startup(sockets) + self.set_started(True) + @contextlib.contextmanager def run_in_thread(self): """ Manages the execution of the server in a separate thread, allowing the FastAPI application to run asynchronously without blocking the main thread of the Axon server. This method is a key component in enabling concurrent request handling in the Axon server. Yields: - None: This method yields control back to the caller while the server is running in the background thread. + thread: a running thread + + Raises: + Exception: in case the server did not start (as signalled by self.get_started()) """ thread = threading.Thread(target=self.run, daemon=True) thread.start() try: - while not self.started: - time.sleep(1e-3) - yield + time_start = time.time() + while not self.get_started() and time.time() - time_start < 1: + time.sleep(TIME_SLEEP_INTERVAL) + if not self.get_started(): + raise Exception("failed to start server") + yield thread finally: self.should_exit = True thread.join() @@ -128,9 +189,15 @@ def _wrapper_run(self): """ A wrapper method for the :func:`run_in_thread` context manager. This method is used internally by the ``start`` method to initiate the server's execution in a separate thread. """ - with self.run_in_thread(): - while not self.should_exit: - time.sleep(1e-3) + try: + with self.run_in_thread() as thread: + self.set_thread(thread) + while not self.should_exit: + if not thread.is_alive(): + raise Exception("worker thread died") + time.sleep(TIME_SLEEP_INTERVAL) + except Exception as e: + self.set_exception(e) def start(self): """ @@ -405,6 +472,26 @@ def info(self) -> "bittensor.AxonInfo": placeholder2=0, ) + @property + def exception(self) -> Optional[Exception]: + """ + Axon objects expose exceptions that occurred internally through the .exception property. + """ + # for future use: setting self._exception to signal an exception + exception = getattr(self, "_exception", None) + if exception: + return exception + return self.fast_server.get_exception() + + def is_running(self) -> bool: + """ + Axon objects can be queried using .is_running() to test whether worker threads are running. + """ + thread = self.fast_server.get_thread() + if thread is None: + return False + return thread.is_alive() + def attach( self, forward_fn: Callable, diff --git a/bittensor/types.py b/bittensor/bt_types.py similarity index 100% rename from bittensor/types.py rename to bittensor/bt_types.py diff --git a/bittensor/commands/stake.py b/bittensor/commands/stake.py index 132529a131..eff415d1a1 100644 --- a/bittensor/commands/stake.py +++ b/bittensor/commands/stake.py @@ -44,19 +44,25 @@ def get_netuid( - cli: "bittensor.cli", subtensor: "bittensor.subtensor" + cli: "bittensor.cli", subtensor: "bittensor.subtensor", prompt: bool = True ) -> Tuple[bool, int]: """Retrieve and validate the netuid from the user or configuration.""" console = Console() - if not cli.config.is_set("netuid"): - try: - cli.config.netuid = int(Prompt.ask("Enter netuid")) - except ValueError: - console.print( - "[red]Invalid input. Please enter a valid integer for netuid.[/red]" - ) - return False, -1 + if not cli.config.is_set("netuid") and prompt: + cli.config.netuid = Prompt.ask("Enter netuid") + try: + cli.config.netuid = int(cli.config.netuid) + except ValueError: + console.print( + "[red]Invalid input. Please enter a valid integer for netuid.[/red]" + ) + return False, -1 netuid = cli.config.netuid + if netuid < 0 or netuid > 65535: + console.print( + "[red]Invalid input. Please enter a valid integer for netuid in subnet range.[/red]" + ) + return False, -1 if not subtensor.subnet_exists(netuid=netuid): console.print( "[red]Network with netuid {} does not exist. Please try again.[/red]".format( @@ -1136,10 +1142,27 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): wallet = bittensor.wallet(config=cli.config) # check all - if not cli.config.is_set("all"): - exists, netuid = get_netuid(cli, subtensor) - if not exists: - return + if cli.config.is_set("all"): + cli.config.netuid = None + cli.config.all = True + elif cli.config.is_set("netuid"): + if cli.config.netuid == "all": + cli.config.all = True + else: + cli.config.netuid = int(cli.config.netuid) + exists, netuid = get_netuid(cli, subtensor) + if not exists: + return + else: + netuid_input = Prompt.ask("Enter netuid or 'all'", default="all") + if netuid_input == "all": + cli.config.netuid = None + cli.config.all = True + else: + cli.config.netuid = int(netuid_input) + exists, netuid = get_netuid(cli, subtensor, False) + if not exists: + return # get parent hotkey hotkey = get_hotkey(wallet, cli.config) @@ -1148,11 +1171,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): return try: - netuids = ( - subtensor.get_all_subnet_netuids() - if cli.config.is_set("all") - else [netuid] - ) + netuids = subtensor.get_all_subnet_netuids() if cli.config.all else [netuid] hotkey_stake = GetChildrenCommand.get_parent_stake_info( console, subtensor, hotkey ) @@ -1236,7 +1255,7 @@ def add_args(parser: argparse.ArgumentParser): parser = parser.add_parser( "get_children", help="""Get child hotkeys on subnet.""" ) - parser.add_argument("--netuid", dest="netuid", type=int, required=False) + parser.add_argument("--netuid", dest="netuid", type=str, required=False) parser.add_argument("--hotkey", dest="hotkey", type=str, required=False) parser.add_argument( "--all", @@ -1294,7 +1313,7 @@ def render_table( # Add columns to the table with specific styles table.add_column("Index", style="bold yellow", no_wrap=True, justify="center") - table.add_column("ChildHotkey", style="bold green") + table.add_column("Child Hotkey", style="bold green") table.add_column("Proportion", style="bold cyan", no_wrap=True, justify="right") table.add_column( "Childkey Take", style="bold blue", no_wrap=True, justify="right" diff --git a/bittensor/subtensor.py b/bittensor/subtensor.py index ac22a3a14d..cc8f8d10c8 100644 --- a/bittensor/subtensor.py +++ b/bittensor/subtensor.py @@ -109,7 +109,7 @@ unstake_extrinsic, unstake_multiple_extrinsic, ) -from .types import AxonServeCallParams, PrometheusServeCallParams +from .bt_types import AxonServeCallParams, PrometheusServeCallParams from .utils import ( U16_NORMALIZED_FLOAT, ss58_to_vec_u8,