Skip to content

Commit

Permalink
Merge pull request #2 from BlackThunder01001/wait_until_auth
Browse files Browse the repository at this point in the history
Added `wait_for`, `wait_until_ready` and `wait_until_disconnected`.
  • Loading branch information
AwesomeSam9523 authored Apr 17, 2022
2 parents 32b9677 + 2ffdafd commit 7812a24
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 9 deletions.
97 changes: 89 additions & 8 deletions winerp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
Coroutine,
TypeVar,
Union,
Dict,
Tuple,
)

logger = logging.getLogger(__name__)
Expand All @@ -28,8 +30,6 @@ class Client:
local_name: :class:`str`
The name which will be used to refer to this client.
This should be unique to all the clients.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
The asyncio loop to use.
port: Optional[:class:`int`]
The port on which the server is running. Defaults to 13254.
"""
Expand All @@ -41,8 +41,9 @@ def __init__(
self.uri = f"ws://localhost:{port}"
self.local_name = local_name
self.websocket = None
self.routes = {}
self.__routes = {}
self.listeners = {}
self.event_listeners: Dict[str, Tuple[asyncio.Future, Callable]] = {}
self._authorized = False
self._on_hold = False
self.events = [
Expand Down Expand Up @@ -127,13 +128,13 @@ def route(self, name: str = None):
The function passed is not a coro.
'''
def route_decorator(func):
if (name is None and func.__name__ in self.routes) or (name is not None and name in self.routes):
raise ValueError("Route name already exists!")
if (name is None and func.__name__ in self.__routes) or (name is not None and name in self.__routes):
raise ValueError("Route name is already registered!")

if not asyncio.iscoroutinefunction(func):
raise InvalidRouteType("Route function must be a coro.")

self.routes[name or func.__name__] = func
self.__routes[name or func.__name__] = func
return func

if isinstance(name, FunctionType):
Expand Down Expand Up @@ -304,6 +305,77 @@ async def inform(
await self.send_message(payload)
else:
raise ClientNotReadyError("The client has not been started or has disconnected")


async def wait_until_ready(self):
'''|coro|
Waits until the client is ready to send or accept requests.
'''
await self.wait_for('winerp_ready', None)

async def wait_until_disconnected(self):
'''|coro|
Waits until the client is disconnected.
'''
await self.wait_for('winerp_disconnect', None)

def wait_for(
self,
event: str,
timeout: int = 60,
check: Callable = None,
):
'''|coro|
Waits for a WebSocket event to be dispatched.
The timeout parameter is passed onto asyncio.wait_for().
By default, it does not timeout.
In case the event returns multiple arguments, a tuple containing those arguments is returned instead.
Please check the documentation for a list of events and their parameters.
This function returns the **first event that meets the requirements.**
Parameters
-----------
event: :class:`str`
The event to wait for.
timeout: Optional[:class:`int`]
Time to wait before raising :class:`~asyncio.TimeoutError`. Defaults to 60.
check: Optional[:class:`Callable`]
A function to check if the event meets the requirements.
If it returns True, the event is returned.
Raises
-------
asyncio.TimeoutError
If the event is not received within the timeout.
Returns
--------
:class:`Any`
The payload for the event that meets the requirements.
'''
future = asyncio.get_event_loop().create_future()
if check is None:

def _check(*args):
return True

check = _check

ev = event.lower()
try:
listeners = self.event_listeners[ev]
except KeyError:
listeners = []
self.event_listeners[ev] = listeners

listeners.append((future, check))
return asyncio.wait_for(future, timeout)


async def __on_message(self):
Expand All @@ -325,7 +397,7 @@ async def __on_message(self):
asyncio.create_task(self._dispatch(message))

elif message.type.request:
if message.route not in self.routes:
if message.route not in self.__routes:
logger.info("Failed to fulfill request, route not found")
payload = MessagePayload(
type = Payloads.error,
Expand Down Expand Up @@ -366,7 +438,7 @@ async def __on_message(self):

async def _fulfill_request(self, message: WsMessage):
route = message.route
func = self.routes[route]
func = self.__routes[route]
data = message.data
payload = MessagePayload().from_message(message)
payload.type = Payloads.response
Expand Down Expand Up @@ -455,6 +527,15 @@ def event(self, func: Coro, /) -> Coro:
return func

def _dispatch_event(self, event_name: str, *args, **kwargs):
logger.debug('Event Dispatch -> %r', event_name)

for ev, data in self.event_listeners.items():
if ev == event_name:
for fut, check in data:
if check(*args, **kwargs):
fut.set_result(None)
logger.debug('Event %r has been dispatched', event_name)

try:
coro = getattr(self, f'on_{event_name}')
except AttributeError:
Expand Down
1 change: 0 additions & 1 deletion winerp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(
self.websocket.set_fn_new_client(self.__on_client_connect)
self.websocket.set_fn_message_received(self.__on_message)
self.websocket.set_fn_client_left(self.__on_client_disconnect)
self.console_output = True
self.active_clients = {}
self.pending_verification = {}
self.on_hold_connections = {}
Expand Down

0 comments on commit 7812a24

Please sign in to comment.