From 2ffdafd51f5aaf8560027707dc3993969988f4b9 Mon Sep 17 00:00:00 2001 From: Samaksh Gupta Date: Sun, 17 Apr 2022 18:03:34 +0530 Subject: [PATCH] Added `wait_for` and its shorthand methods. Co-authored-by: Black Thunder --- winerp/client.py | 97 ++++++++++++++++++++++++++++++++++++++++++++---- winerp/server.py | 1 - 2 files changed, 89 insertions(+), 9 deletions(-) diff --git a/winerp/client.py b/winerp/client.py index df1d35c..c7d9573 100644 --- a/winerp/client.py +++ b/winerp/client.py @@ -14,6 +14,8 @@ Coroutine, TypeVar, Union, + Dict, + Tuple, ) logger = logging.getLogger(__name__) @@ -28,8 +30,6 @@ class Client: local_name: :class:`str` The name which will be used to refer to this client. This should be unique to all the clients. - loop: Optional[:class:`asyncio.AbstractEventLoop`] - The asyncio loop to use. port: Optional[:class:`int`] The port on which the server is running. Defaults to 13254. """ @@ -41,8 +41,9 @@ def __init__( self.uri = f"ws://localhost:{port}" self.local_name = local_name self.websocket = None - self.routes = {} + self.__routes = {} self.listeners = {} + self.event_listeners: Dict[str, Tuple[asyncio.Future, Callable]] = {} self._authorized = False self._on_hold = False self.events = [ @@ -127,13 +128,13 @@ def route(self, name: str = None): The function passed is not a coro. ''' def route_decorator(func): - if (name is None and func.__name__ in self.routes) or (name is not None and name in self.routes): - raise ValueError("Route name already exists!") + if (name is None and func.__name__ in self.__routes) or (name is not None and name in self.__routes): + raise ValueError("Route name is already registered!") if not asyncio.iscoroutinefunction(func): raise InvalidRouteType("Route function must be a coro.") - self.routes[name or func.__name__] = func + self.__routes[name or func.__name__] = func return func if isinstance(name, FunctionType): @@ -304,6 +305,77 @@ async def inform( await self.send_message(payload) else: raise ClientNotReadyError("The client has not been started or has disconnected") + + + async def wait_until_ready(self): + '''|coro| + + Waits until the client is ready to send or accept requests. + ''' + await self.wait_for('winerp_ready', None) + + async def wait_until_disconnected(self): + '''|coro| + + Waits until the client is disconnected. + ''' + await self.wait_for('winerp_disconnect', None) + + def wait_for( + self, + event: str, + timeout: int = 60, + check: Callable = None, + ): + '''|coro| + + Waits for a WebSocket event to be dispatched. + + The timeout parameter is passed onto asyncio.wait_for(). + By default, it does not timeout. + + In case the event returns multiple arguments, a tuple containing those arguments is returned instead. + Please check the documentation for a list of events and their parameters. + + This function returns the **first event that meets the requirements.** + + Parameters + ----------- + event: :class:`str` + The event to wait for. + timeout: Optional[:class:`int`] + Time to wait before raising :class:`~asyncio.TimeoutError`. Defaults to 60. + check: Optional[:class:`Callable`] + A function to check if the event meets the requirements. + If it returns True, the event is returned. + + Raises + ------- + asyncio.TimeoutError + If the event is not received within the timeout. + + Returns + -------- + :class:`Any` + The payload for the event that meets the requirements. + ''' + future = asyncio.get_event_loop().create_future() + if check is None: + + def _check(*args): + return True + + check = _check + + ev = event.lower() + try: + listeners = self.event_listeners[ev] + except KeyError: + listeners = [] + self.event_listeners[ev] = listeners + + listeners.append((future, check)) + return asyncio.wait_for(future, timeout) async def __on_message(self): @@ -325,7 +397,7 @@ async def __on_message(self): asyncio.create_task(self._dispatch(message)) elif message.type.request: - if message.route not in self.routes: + if message.route not in self.__routes: logger.info("Failed to fulfill request, route not found") payload = MessagePayload( type = Payloads.error, @@ -366,7 +438,7 @@ async def __on_message(self): async def _fulfill_request(self, message: WsMessage): route = message.route - func = self.routes[route] + func = self.__routes[route] data = message.data payload = MessagePayload().from_message(message) payload.type = Payloads.response @@ -455,6 +527,15 @@ def event(self, func: Coro, /) -> Coro: return func def _dispatch_event(self, event_name: str, *args, **kwargs): + logger.debug('Event Dispatch -> %r', event_name) + + for ev, data in self.event_listeners.items(): + if ev == event_name: + for fut, check in data: + if check(*args, **kwargs): + fut.set_result(None) + logger.debug('Event %r has been dispatched', event_name) + try: coro = getattr(self, f'on_{event_name}') except AttributeError: diff --git a/winerp/server.py b/winerp/server.py index 227c20c..2fb716a 100644 --- a/winerp/server.py +++ b/winerp/server.py @@ -29,7 +29,6 @@ def __init__( self.websocket.set_fn_new_client(self.__on_client_connect) self.websocket.set_fn_message_received(self.__on_message) self.websocket.set_fn_client_left(self.__on_client_disconnect) - self.console_output = True self.active_clients = {} self.pending_verification = {} self.on_hold_connections = {}