Skip to content

Commit

Permalink
More typings
Browse files Browse the repository at this point in the history
  • Loading branch information
devoxin committed Feb 17, 2023
1 parent a6719b9 commit 60b6eab
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 53 deletions.
70 changes: 38 additions & 32 deletions lavalink/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .models import AudioTrack, BasePlayer, DeferredAudioTrack
from .node import Node


class Event:
Expand All @@ -41,8 +47,8 @@ class TrackStartEvent(Event):
__slots__ = ('player', 'track')

def __init__(self, player, track):
self.player = player
self.track = track
self.player: 'BasePlayer' = player
self.track: 'AudioTrack' = track


class TrackStuckEvent(Event):
Expand All @@ -63,9 +69,9 @@ class TrackStuckEvent(Event):
__slots__ = ('player', 'track', 'threshold')

def __init__(self, player, track, threshold):
self.player = player
self.track = track
self.threshold = threshold
self.player: 'BasePlayer' = player
self.track: 'AudioTrack' = track
self.threshold: int = threshold


class TrackExceptionEvent(Event):
Expand All @@ -86,10 +92,10 @@ class TrackExceptionEvent(Event):
__slots__ = ('player', 'track', 'exception', 'severity')

def __init__(self, player, track, exception, severity):
self.player = player
self.track = track
self.exception = exception
self.severity = severity
self.player: 'BasePlayer' = player
self.track: 'AudioTrack' = track
self.exception: str = exception
self.severity: str = severity


class TrackEndEvent(Event):
Expand All @@ -109,9 +115,9 @@ class TrackEndEvent(Event):
__slots__ = ('player', 'track', 'reason')

def __init__(self, player, track, reason):
self.player = player
self.track = track
self.reason = reason
self.player: 'BasePlayer' = player
self.track: Optional['AudioTrack'] = track
self.reason: str = reason


class TrackLoadFailedEvent(Event):
Expand All @@ -134,9 +140,9 @@ class TrackLoadFailedEvent(Event):
__slots__ = ('player', 'track', 'original')

def __init__(self, player, track, original):
self.player = player
self.track = track
self.original = original
self.player: 'BasePlayer' = player
self.track: 'DeferredAudioTrack' = track
self.original: Optional[Exception] = original


class QueueEndEvent(Event):
Expand All @@ -152,7 +158,7 @@ class QueueEndEvent(Event):
__slots__ = ('player',)

def __init__(self, player):
self.player = player
self.player: 'BasePlayer' = player


class PlayerUpdateEvent(Event):
Expand All @@ -177,11 +183,11 @@ class PlayerUpdateEvent(Event):
__slots__ = ('player', 'position', 'timestamp', 'connected', 'ping')

def __init__(self, player, raw_state):
self.player = player
self.position = raw_state.get('position')
self.timestamp = raw_state.get('time')
self.connected = raw_state.get('connected')
self.ping = raw_state.get('ping', -1)
self.player: 'BasePlayer' = player
self.position: int = raw_state.get('position')
self.timestamp: int = raw_state.get('time')
self.connected: bool = raw_state.get('connected')
self.ping: int = raw_state.get('ping', -1)


class NodeConnectedEvent(Event):
Expand All @@ -197,7 +203,7 @@ class NodeConnectedEvent(Event):
__slots__ = ('node',)

def __init__(self, node):
self.node = node
self.node: 'Node' = node


class NodeDisconnectedEvent(Event):
Expand All @@ -217,9 +223,9 @@ class NodeDisconnectedEvent(Event):
__slots__ = ('node', 'code', 'reason')

def __init__(self, node, code, reason):
self.node = node
self.code = code
self.reason = reason
self.node: 'Node' = node
self.code: Optional[int] = code
self.reason: Optional[str] = reason


class NodeChangedEvent(Event):
Expand All @@ -240,9 +246,9 @@ class NodeChangedEvent(Event):
__slots__ = ('player', 'old_node', 'new_node')

def __init__(self, player, old_node, new_node):
self.player = player
self.old_node = old_node
self.new_node = new_node
self.player: 'BasePlayer' = player
self.old_node: 'Node' = old_node
self.new_node: 'Node' = new_node


class WebSocketClosedEvent(Event):
Expand All @@ -269,7 +275,7 @@ class WebSocketClosedEvent(Event):
__slots__ = ('player', 'code', 'reason', 'by_remote')

def __init__(self, player, code, reason, by_remote):
self.player = player
self.code = code
self.reason = reason
self.by_remote = by_remote
self.player: 'BasePlayer' = player
self.code: int = code
self.reason: str = reason
self.by_remote: bool = by_remote
25 changes: 12 additions & 13 deletions lavalink/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
"""
from typing import List

from lavalink.models import Plugin

from .events import Event
from .models import BasePlayer, LoadResult, Plugin # noqa: F401
from .stats import Stats
from .websocket import WebSocket

Expand Down Expand Up @@ -66,22 +65,22 @@ def __init__(self, manager, host: str, port: int, password: str,
self._manager = manager
self._ws = WebSocket(self, host, port, password, ssl, resume_key, resume_timeout, reconnect_attempts)

self.host = host
self.port = port
self.password = password
self.ssl = ssl
self.region = region
self.name = name or '{}-{}:{}'.format(self.region, self.host, self.port)
self.filters = filters
self.stats = Stats.empty(self)
self.host: str = host
self.port: int = port
self.password: str = password
self.ssl: bool = ssl
self.region: str = region
self.name: str = name or '{}-{}:{}'.format(self.region, self.host, self.port)
self.filters: bool = filters
self.stats: Stats = Stats.empty(self)

@property
def available(self) -> bool:
""" Returns whether the node is available for requests. """
return self._ws.connected

@property
def _original_players(self):
def _original_players(self) -> List[BasePlayer]:
"""
Returns a list of players that were assigned to this node, but were moved due to failover etc.
Expand All @@ -92,7 +91,7 @@ def _original_players(self):
return [p for p in self._lavalink.player_manager.values() if p._original_node == self]

@property
def players(self):
def players(self) -> List[BasePlayer]:
"""
Returns a list of all players on this node.
Expand Down Expand Up @@ -122,7 +121,7 @@ async def destroy(self):
"""
await self._ws.destroy()

async def get_tracks(self, query: str, check_local: bool = False):
async def get_tracks(self, query: str, check_local: bool = False) -> LoadResult:
"""|coro|
Retrieves a list of results pertaining to the provided query.
Expand Down
9 changes: 5 additions & 4 deletions lavalink/nodemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SOFTWARE.
"""
import logging
from typing import List, Optional

from .events import NodeConnectedEvent, NodeDisconnectedEvent
from .node import Node
Expand Down Expand Up @@ -63,7 +64,7 @@ def __iter__(self):
yield n

@property
def available_nodes(self):
def available_nodes(self) -> List[Node]:
""" Returns a list of available nodes. """
return [n for n in self.nodes if n.available]

Expand Down Expand Up @@ -123,7 +124,7 @@ def remove_node(self, node: Node):
self.nodes.remove(node)
_log.info('Removed node \'%s\'', node.name)

def get_nodes_by_region(self, region_key: str):
def get_nodes_by_region(self, region_key: str) -> List[Node]:
"""
Get a list of nodes by their region.
This does not account for node availability, so the nodes returned
Expand All @@ -143,7 +144,7 @@ def get_nodes_by_region(self, region_key: str):
"""
return [n for n in self.nodes if n.region == region_key]

def get_region(self, endpoint: str):
def get_region(self, endpoint: str) -> str:
"""
Returns a Lavalink.py-friendly region from a Discord voice server address.
Expand All @@ -170,7 +171,7 @@ def get_region(self, endpoint: str):

return None

def find_ideal_node(self, region: str = None):
def find_ideal_node(self, region: str = None) -> Optional[Node]:
"""
Finds the best (least used) node in the given region, if applicable.
Expand Down
9 changes: 5 additions & 4 deletions lavalink/playermanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SOFTWARE.
"""
import logging
from typing import Callable, Dict, Iterator

from .errors import NodeError
from .models import BasePlayer
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, lavalink, player):

self._lavalink = lavalink
self._player_cls = player
self.players = {}
self.players: Dict[int, BasePlayer] = {}

def __len__(self):
return len(self.players)
Expand All @@ -62,18 +63,18 @@ def __iter__(self):
for guild_id, player in self.players.items():
yield guild_id, player

def values(self):
def values(self) -> Iterator[BasePlayer]:
""" Returns an iterator that yields only values. """
for player in self.players.values():
yield player

def find_all(self, predicate=None):
def find_all(self, predicate: Callable[[BasePlayer], bool] = None):
"""
Returns a list of players that match the given predicate.
Parameters
----------
predicate: Optional[:class:`function`]
predicate: Optional[Callable[[:class:BasePlayer], bool]]
A predicate to return specific players. Defaults to ``None``.
Returns
Expand Down

0 comments on commit 60b6eab

Please sign in to comment.