From 60b6eab13cb2b55e891ab21e3d3652b182071066 Mon Sep 17 00:00:00 2001 From: Devoxin Date: Fri, 17 Feb 2023 22:16:35 +0000 Subject: [PATCH] More typings --- lavalink/events.py | 70 +++++++++++++++++++++------------------ lavalink/node.py | 25 +++++++------- lavalink/nodemanager.py | 9 ++--- lavalink/playermanager.py | 9 ++--- 4 files changed, 60 insertions(+), 53 deletions(-) diff --git a/lavalink/events.py b/lavalink/events.py index c5cdb141..9b147b07 100644 --- a/lavalink/events.py +++ b/lavalink/events.py @@ -21,6 +21,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + # pylint: disable=cyclic-import + from .models import AudioTrack, BasePlayer, DeferredAudioTrack + from .node import Node class Event: @@ -41,8 +47,8 @@ class TrackStartEvent(Event): __slots__ = ('player', 'track') def __init__(self, player, track): - self.player = player - self.track = track + self.player: 'BasePlayer' = player + self.track: 'AudioTrack' = track class TrackStuckEvent(Event): @@ -63,9 +69,9 @@ class TrackStuckEvent(Event): __slots__ = ('player', 'track', 'threshold') def __init__(self, player, track, threshold): - self.player = player - self.track = track - self.threshold = threshold + self.player: 'BasePlayer' = player + self.track: 'AudioTrack' = track + self.threshold: int = threshold class TrackExceptionEvent(Event): @@ -86,10 +92,10 @@ class TrackExceptionEvent(Event): __slots__ = ('player', 'track', 'exception', 'severity') def __init__(self, player, track, exception, severity): - self.player = player - self.track = track - self.exception = exception - self.severity = severity + self.player: 'BasePlayer' = player + self.track: 'AudioTrack' = track + self.exception: str = exception + self.severity: str = severity class TrackEndEvent(Event): @@ -109,9 +115,9 @@ class TrackEndEvent(Event): __slots__ = ('player', 'track', 'reason') def __init__(self, player, track, reason): - self.player = player - self.track = track - self.reason = reason + self.player: 'BasePlayer' = player + self.track: Optional['AudioTrack'] = track + self.reason: str = reason class TrackLoadFailedEvent(Event): @@ -134,9 +140,9 @@ class TrackLoadFailedEvent(Event): __slots__ = ('player', 'track', 'original') def __init__(self, player, track, original): - self.player = player - self.track = track - self.original = original + self.player: 'BasePlayer' = player + self.track: 'DeferredAudioTrack' = track + self.original: Optional[Exception] = original class QueueEndEvent(Event): @@ -152,7 +158,7 @@ class QueueEndEvent(Event): __slots__ = ('player',) def __init__(self, player): - self.player = player + self.player: 'BasePlayer' = player class PlayerUpdateEvent(Event): @@ -177,11 +183,11 @@ class PlayerUpdateEvent(Event): __slots__ = ('player', 'position', 'timestamp', 'connected', 'ping') def __init__(self, player, raw_state): - self.player = player - self.position = raw_state.get('position') - self.timestamp = raw_state.get('time') - self.connected = raw_state.get('connected') - self.ping = raw_state.get('ping', -1) + self.player: 'BasePlayer' = player + self.position: int = raw_state.get('position') + self.timestamp: int = raw_state.get('time') + self.connected: bool = raw_state.get('connected') + self.ping: int = raw_state.get('ping', -1) class NodeConnectedEvent(Event): @@ -197,7 +203,7 @@ class NodeConnectedEvent(Event): __slots__ = ('node',) def __init__(self, node): - self.node = node + self.node: 'Node' = node class NodeDisconnectedEvent(Event): @@ -217,9 +223,9 @@ class NodeDisconnectedEvent(Event): __slots__ = ('node', 'code', 'reason') def __init__(self, node, code, reason): - self.node = node - self.code = code - self.reason = reason + self.node: 'Node' = node + self.code: Optional[int] = code + self.reason: Optional[str] = reason class NodeChangedEvent(Event): @@ -240,9 +246,9 @@ class NodeChangedEvent(Event): __slots__ = ('player', 'old_node', 'new_node') def __init__(self, player, old_node, new_node): - self.player = player - self.old_node = old_node - self.new_node = new_node + self.player: 'BasePlayer' = player + self.old_node: 'Node' = old_node + self.new_node: 'Node' = new_node class WebSocketClosedEvent(Event): @@ -269,7 +275,7 @@ class WebSocketClosedEvent(Event): __slots__ = ('player', 'code', 'reason', 'by_remote') def __init__(self, player, code, reason, by_remote): - self.player = player - self.code = code - self.reason = reason - self.by_remote = by_remote + self.player: 'BasePlayer' = player + self.code: int = code + self.reason: str = reason + self.by_remote: bool = by_remote diff --git a/lavalink/node.py b/lavalink/node.py index 101e4695..dfb708bc 100644 --- a/lavalink/node.py +++ b/lavalink/node.py @@ -23,9 +23,8 @@ """ from typing import List -from lavalink.models import Plugin - from .events import Event +from .models import BasePlayer, LoadResult, Plugin # noqa: F401 from .stats import Stats from .websocket import WebSocket @@ -66,14 +65,14 @@ def __init__(self, manager, host: str, port: int, password: str, self._manager = manager self._ws = WebSocket(self, host, port, password, ssl, resume_key, resume_timeout, reconnect_attempts) - self.host = host - self.port = port - self.password = password - self.ssl = ssl - self.region = region - self.name = name or '{}-{}:{}'.format(self.region, self.host, self.port) - self.filters = filters - self.stats = Stats.empty(self) + self.host: str = host + self.port: int = port + self.password: str = password + self.ssl: bool = ssl + self.region: str = region + self.name: str = name or '{}-{}:{}'.format(self.region, self.host, self.port) + self.filters: bool = filters + self.stats: Stats = Stats.empty(self) @property def available(self) -> bool: @@ -81,7 +80,7 @@ def available(self) -> bool: return self._ws.connected @property - def _original_players(self): + def _original_players(self) -> List[BasePlayer]: """ Returns a list of players that were assigned to this node, but were moved due to failover etc. @@ -92,7 +91,7 @@ def _original_players(self): return [p for p in self._lavalink.player_manager.values() if p._original_node == self] @property - def players(self): + def players(self) -> List[BasePlayer]: """ Returns a list of all players on this node. @@ -122,7 +121,7 @@ async def destroy(self): """ await self._ws.destroy() - async def get_tracks(self, query: str, check_local: bool = False): + async def get_tracks(self, query: str, check_local: bool = False) -> LoadResult: """|coro| Retrieves a list of results pertaining to the provided query. diff --git a/lavalink/nodemanager.py b/lavalink/nodemanager.py index 52ca9723..1b1ad080 100644 --- a/lavalink/nodemanager.py +++ b/lavalink/nodemanager.py @@ -22,6 +22,7 @@ SOFTWARE. """ import logging +from typing import List, Optional from .events import NodeConnectedEvent, NodeDisconnectedEvent from .node import Node @@ -63,7 +64,7 @@ def __iter__(self): yield n @property - def available_nodes(self): + def available_nodes(self) -> List[Node]: """ Returns a list of available nodes. """ return [n for n in self.nodes if n.available] @@ -123,7 +124,7 @@ def remove_node(self, node: Node): self.nodes.remove(node) _log.info('Removed node \'%s\'', node.name) - def get_nodes_by_region(self, region_key: str): + def get_nodes_by_region(self, region_key: str) -> List[Node]: """ Get a list of nodes by their region. This does not account for node availability, so the nodes returned @@ -143,7 +144,7 @@ def get_nodes_by_region(self, region_key: str): """ return [n for n in self.nodes if n.region == region_key] - def get_region(self, endpoint: str): + def get_region(self, endpoint: str) -> str: """ Returns a Lavalink.py-friendly region from a Discord voice server address. @@ -170,7 +171,7 @@ def get_region(self, endpoint: str): return None - def find_ideal_node(self, region: str = None): + def find_ideal_node(self, region: str = None) -> Optional[Node]: """ Finds the best (least used) node in the given region, if applicable. diff --git a/lavalink/playermanager.py b/lavalink/playermanager.py index dc4f62ad..c63c8bb0 100644 --- a/lavalink/playermanager.py +++ b/lavalink/playermanager.py @@ -22,6 +22,7 @@ SOFTWARE. """ import logging +from typing import Callable, Dict, Iterator from .errors import NodeError from .models import BasePlayer @@ -52,7 +53,7 @@ def __init__(self, lavalink, player): self._lavalink = lavalink self._player_cls = player - self.players = {} + self.players: Dict[int, BasePlayer] = {} def __len__(self): return len(self.players) @@ -62,18 +63,18 @@ def __iter__(self): for guild_id, player in self.players.items(): yield guild_id, player - def values(self): + def values(self) -> Iterator[BasePlayer]: """ Returns an iterator that yields only values. """ for player in self.players.values(): yield player - def find_all(self, predicate=None): + def find_all(self, predicate: Callable[[BasePlayer], bool] = None): """ Returns a list of players that match the given predicate. Parameters ---------- - predicate: Optional[:class:`function`] + predicate: Optional[Callable[[:class:BasePlayer], bool]] A predicate to return specific players. Defaults to ``None``. Returns