diff --git a/.flake8 b/.flake8 index fa030e74..73c76f23 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,2 @@ [flake8] -max-line-length = 300 +max-line-length = 150 diff --git a/.gitignore b/.gitignore index 738d0097..721220cd 100644 --- a/.gitignore +++ b/.gitignore @@ -100,5 +100,7 @@ ENV/ # mypy .mypy_cache/ +\.DS_store + \.idea/ \.vscode/ diff --git a/README.md b/README.md index 758d14eb..fa227d02 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ # Lavalink.py -[![Python](https://img.shields.io/badge/Python-3.5%20%7C%203.6%20%7C%203.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://www.python.org) [![Build Status](https://travis-ci.com/devoxin/Lavalink.py.svg?branch=master)](https://travis-ci.com/Devoxin/Lavalink.py) [![Codacy Badge](https://app.codacy.com/project/badge/Grade/428eebed5a2e467fb038eacfa1d92e62)](https://www.codacy.com/gh/Devoxin/Lavalink.py/dashboard?utm_source=github.com&utm_medium=referral&utm_content=Devoxin/Lavalink.py&utm_campaign=Badge_Grade) [![License](https://img.shields.io/github/license/Devoxin/Lavalink.py.svg)](LICENSE) [![Documentation Status](https://readthedocs.org/projects/lavalink/badge/?version=latest)](https://lavalink.readthedocs.io/en/latest/?badge=latest) +[![Python](https://img.shields.io/badge/Python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://www.python.org) [![Build Status](https://travis-ci.com/devoxin/Lavalink.py.svg?branch=master)](https://travis-ci.com/Devoxin/Lavalink.py) [![Codacy Badge](https://app.codacy.com/project/badge/Grade/428eebed5a2e467fb038eacfa1d92e62)](https://www.codacy.com/gh/Devoxin/Lavalink.py/dashboard?utm_source=github.com&utm_medium=referral&utm_content=Devoxin/Lavalink.py&utm_campaign=Badge_Grade) [![License](https://img.shields.io/github/license/Devoxin/Lavalink.py.svg)](LICENSE) [![Documentation Status](https://readthedocs.org/projects/lavalink/badge/?version=latest)](https://lavalink.readthedocs.io/en/latest/?badge=latest) Lavalink.py is a wrapper for [Lavalink] which abstracts away most of the code necessary to use Lavalink, allowing for easier integration into your projects, while still promising full API coverage and powerful tools to get the most out of it. diff --git a/docs/conf.py b/docs/conf.py index 648f62f5..06963abe 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ # -- Project information ----------------------------------------------------- project = 'Lavalink.py' -copyright = '2022, Devoxin' +copyright = '2023, Devoxin' author = 'Devoxin' master_doc = 'index' diff --git a/docs/lavalink.rst b/docs/lavalink.rst index 0254641a..80cf83c7 100644 --- a/docs/lavalink.rst +++ b/docs/lavalink.rst @@ -7,7 +7,19 @@ Documentation .. autofunction:: listener -.. autofunction:: add_event_hook +ABC +--- +.. autoclass:: BasePlayer + :members: + +.. autoclass:: DeferredAudioTrack + :members: + +.. autoclass:: Source + :members: + +.. autoclass:: Filter + :members: Client ------ @@ -16,12 +28,20 @@ Client Errors ------ -.. autoclass:: NodeError +.. autoclass:: ClientError .. autoclass:: AuthenticationError .. autoclass:: InvalidTrack +.. autoclass:: LoadError + +.. autoclass:: RequestError + :members: + +.. autoclass:: PlayerErrorEvent + :members: + Events ------ All Events are derived from :class:`Event` @@ -59,16 +79,19 @@ All Events are derived from :class:`Event` .. autoclass:: NodeChangedEvent :members: +.. autoclass:: NodeReadyEvent + :members: + .. autoclass:: WebSocketClosedEvent :members: +.. autoclass:: IncomingWebSocketMessage + :members: + Filters ------- **All** custom filters must derive from :class:`Filter` -.. autoclass:: Filter - :members: - .. autoclass:: Equalizer :members: @@ -96,32 +119,37 @@ Filters .. autoclass:: Volume :members: -Models +.. autoclass:: Distortion + :members: + +Player ------ **All** custom players must derive from :class:`BasePlayer` -.. autoclass:: AudioTrack +.. autoclass:: DefaultPlayer :members: -.. autoclass:: DeferredAudioTrack +Server +------ +.. autoclass:: AudioTrack :members: -.. autoenum:: LoadType +.. autoenum:: EndReason :members: -.. autoclass:: PlaylistInfo +.. autoenum:: LoadType :members: -.. autoclass:: LoadResult +.. autoenum:: Severity :members: -.. autoclass:: Source +.. autoclass:: PlaylistInfo :members: -.. autoclass:: BasePlayer +.. autoclass:: LoadResultError :members: -.. autoclass:: DefaultPlayer +.. autoclass:: LoadResult :members: .. autoclass:: Plugin @@ -159,3 +187,5 @@ Utilities .. autofunction:: parse_time .. autofunction:: decode_track + +.. autofunction:: encode_track diff --git a/docs/license.rst b/docs/license.rst index cff6314b..61fb6ed1 100644 --- a/docs/license.rst +++ b/docs/license.rst @@ -4,7 +4,7 @@ Licensed using the `MIT license `_. MIT License - Copyright (c) 2022 Devoxin + Copyright (c) 2023 Devoxin Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/examples/hikari_music.py b/examples/hikari_music.py index b294a9c8..e889b3b2 100644 --- a/examples/hikari_music.py +++ b/examples/hikari_music.py @@ -83,7 +83,7 @@ async def _join(ctx: lightbulb.Context): # user not in voice channel if not voice_state: return - + channel_id = voice_state[0].channel_id # channel user is connected to plugin.bot.d.lavalink.player_manager.create(guild_id=ctx.guild_id) await plugin.bot.update_voice_state(ctx.guild_id, channel_id, self_deaf=True) diff --git a/examples/music.py b/examples/music.py index 3a2e3db3..b74b7543 100644 --- a/examples/music.py +++ b/examples/music.py @@ -11,12 +11,15 @@ import discord import lavalink from discord.ext import commands +from lavalink.events import TrackStartEvent, QueueEndEvent +from lavalink.errors import ClientError from lavalink.filters import LowPass +from lavalink.server import LoadType url_rx = re.compile(r'https?://(?:www\.)?.+') -class LavalinkVoiceClient(discord.VoiceClient): +class LavalinkVoiceClient(discord.VoiceProtocol): """ This is the preferred way to handle external voice sending This client will be created via a cls in the connect method of the channel @@ -27,19 +30,19 @@ class LavalinkVoiceClient(discord.VoiceClient): def __init__(self, client: discord.Client, channel: discord.abc.Connectable): self.client = client self.channel = channel - # ensure a client already exists - if hasattr(self.client, 'lavalink'): - self.lavalink = self.client.lavalink - else: + self.guild_id = channel.guild.id + self._destroyed = False + + if not hasattr(self.client, 'lavalink'): + # Instantiate a client if one doesn't exist. + # We store it in `self.client` so that it may persist across cog reloads, + # however this is not mandatory. self.client.lavalink = lavalink.Client(client.user.id) - self.client.lavalink.add_node( - 'localhost', - 2333, - 'youshallnotpass', - 'us', - 'default-node' - ) - self.lavalink = self.client.lavalink + self.client.lavalink.add_node(host='localhost', port=2333, password='youshallnotpass', + region='us', name='default-node') + + # Create a shortcut to the Lavalink client here. + self.lavalink = self.client.lavalink async def on_voice_server_update(self, data): # the data needs to be transformed before being handed down to @@ -51,12 +54,21 @@ async def on_voice_server_update(self, data): await self.lavalink.voice_update_handler(lavalink_data) async def on_voice_state_update(self, data): + channel_id = data['channel_id'] + + if not channel_id: + await self._destroy() + return + + self.channel = self.client.get_channel(int(channel_id)) + # the data needs to be transformed before being handed down to # voice_update_handler lavalink_data = { 't': 'VOICE_STATE_UPDATE', 'd': data } + await self.lavalink.voice_update_handler(lavalink_data) async def connect(self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False) -> None: @@ -86,34 +98,44 @@ async def disconnect(self, *, force: bool = False) -> None: # this must be done because the on_voice_state_update that would set channel_id # to None doesn't get dispatched after the disconnect player.channel_id = None + await self._destroy() + + async def _destroy(self): self.cleanup() + if self._destroyed: + # Idempotency handling, if `disconnect()` is called, the changed voice state + # could cause this to run a second time. + return + + self._destroyed = True + + try: + await self.lavalink.player_manager.destroy(self.guild_id) + except ClientError: + pass + class Music(commands.Cog): def __init__(self, bot): self.bot = bot - if not hasattr(bot, 'lavalink'): # This ensures the client isn't overwritten during cog reloads. + if not hasattr(bot, 'lavalink'): bot.lavalink = lavalink.Client(bot.user.id) - bot.lavalink.add_node('127.0.0.1', 2333, 'youshallnotpass', 'eu', 'default-node') # Host, Port, Password, Region, Name + bot.lavalink.add_node(host='localhost', port=2333, password='youshallnotpass', + region='us', name='default-node') - lavalink.add_event_hook(self.track_hook) + self.lavalink: lavalink.Client = bot.lavalink + self.lavalink.add_event_hooks(self) def cog_unload(self): - """ Cog unload handler. This removes any event hooks that were registered. """ - self.bot.lavalink._event_hooks.clear() - - async def cog_before_invoke(self, ctx): - """ Command before-invoke handler. """ - guild_check = ctx.guild is not None - # This is essentially the same as `@commands.guild_only()` - # except it saves us repeating ourselves (and also a few lines). - - if guild_check: - await self.ensure_voice(ctx) - # Ensure that the bot and command author share a mutual voicechannel. + """ + This will remove any registered event hooks when the cog is unloaded. + They will subsequently be registered again once the cog is loaded. - return guild_check + This effectively allows for event handlers to be updated when the cog is reloaded. + """ + self.lavalink._event_hooks.clear() async def cog_command_error(self, ctx, error): if isinstance(error, commands.CommandInvokeError): @@ -123,9 +145,17 @@ async def cog_command_error(self, ctx, error): # which contain a reason string, such as "Join a voicechannel" etc. You can modify the above # if you want to do things differently. - async def ensure_voice(self, ctx): - """ This check ensures that the bot and command author are in the same voicechannel. """ - player = self.bot.lavalink.player_manager.create(ctx.guild.id) + async def create_player(ctx: commands.Context): + """ + A check that is invoked before any commands marked with `@commands.check(create_player)` can run. + + This function will try to create a player for the guild associated with this Context, or raise + an error which will be relayed to the user if one cannot be created. + """ + if ctx.guild is None: + raise commands.NoPrivateMessage() + + player = ctx.bot.lavalink.player_manager.create(ctx.guild.id) # Create returns a player if one exists, otherwise creates. # This line is important because it ensures that a player always exists for a guild. @@ -136,38 +166,64 @@ async def ensure_voice(self, ctx): # Commands such as volume/skip etc don't require the bot to be in a voicechannel so don't need listing here. should_connect = ctx.command.name in ('play',) + voice_client = ctx.voice_client + if not ctx.author.voice or not ctx.author.voice.channel: - # Our cog_command_error handler catches this and sends it to the voicechannel. - # Exceptions allow us to "short-circuit" command invocation via checks so the - # execution state of the command goes no further. + # Check if we're in a voice channel. If we are, tell the user to join our voice channel. + if voice_client is not None: + raise commands.CommandInvokeError('You need to join my voice channel first.') + + # Otherwise, tell them to join any voice channel to begin playing music. raise commands.CommandInvokeError('Join a voicechannel first.') - v_client = ctx.voice_client - if not v_client: + voice_channel = ctx.author.voice.channel + + if voice_client is None: if not should_connect: - raise commands.CommandInvokeError('Not connected.') + raise commands.CommandInvokeError("I'm not playing music.") - permissions = ctx.author.voice.channel.permissions_for(ctx.me) + permissions = voice_channel.permissions_for(ctx.me) - if not permissions.connect or not permissions.speak: # Check user limit too? + if not permissions.connect or not permissions.speak: raise commands.CommandInvokeError('I need the `CONNECT` and `SPEAK` permissions.') + if voice_channel.user_limit > 0: + # A limit of 0 means no limit. Anything higher means that there is a member limit which we need to check. + # If it's full, and we don't have "move members" permissions, then we cannot join it. + if len(voice_channel.members) >= voice_channel.user_limit and not ctx.me.guild_permissions.move_members: + raise commands.CommandInvokeError('Your voice channel is full!') + player.store('channel', ctx.channel.id) await ctx.author.voice.channel.connect(cls=LavalinkVoiceClient) - else: - if v_client.channel.id != ctx.author.voice.channel.id: - raise commands.CommandInvokeError('You need to be in my voicechannel.') - - async def track_hook(self, event): - if isinstance(event, lavalink.events.QueueEndEvent): - # When this track_hook receives a "QueueEndEvent" from lavalink.py - # it indicates that there are no tracks left in the player's queue. - # To save on resources, we can tell the bot to disconnect from the voicechannel. - guild_id = event.player.guild_id - guild = self.bot.get_guild(guild_id) + elif voice_client.channel.id != voice_channel.id: + raise commands.CommandInvokeError('You need to be in my voicechannel.') + + return True + + @lavalink.listener(TrackStartEvent) + async def on_track_start(self, event: TrackStartEvent): + guild_id = event.player.guild_id + channel_id = event.player.fetch('channel') + guild = self.bot.get_guild(guild_id) + + if not guild: + return await self.lavalink.player_manager.destroy(guild_id) + + channel = guild.get_channel(channel_id) + + if channel: + await channel.send('Now playing: {} by {}'.format(event.track.title, event.track.author)) + + @lavalink.listener(QueueEndEvent) + async def on_queue_end(self, event: QueueEndEvent): + guild_id = event.player.guild_id + guild = self.bot.get_guild(guild_id) + + if guild is not None: await guild.voice_client.disconnect(force=True) @commands.command(aliases=['p']) + @commands.check(create_player) async def play(self, ctx, *, query: str): """ Searches and plays a song from a given query. """ # Get the player for this guild from cache. @@ -183,25 +239,24 @@ async def play(self, ctx, *, query: str): # Get the results for the query from Lavalink. results = await player.node.get_tracks(query) - # Results could be None if Lavalink returns an invalid response (non-JSON/non-200 (OK)). - # Alternatively, results.tracks could be an empty array if the query yielded no tracks. - if not results or not results.tracks: - return await ctx.send('Nothing found!') - embed = discord.Embed(color=discord.Color.blurple()) - # Valid loadTypes are: - # TRACK_LOADED - single video/direct URL) - # PLAYLIST_LOADED - direct URL to playlist) - # SEARCH_RESULT - query prefixed with either ytsearch: or scsearch:. - # NO_MATCHES - query yielded no results - # LOAD_FAILED - most likely, the video encountered an exception during loading. - if results.load_type == 'PLAYLIST_LOADED': + # Valid load_types are: + # TRACK - direct URL to a track + # PLAYLIST - direct URL to playlist + # SEARCH - query prefixed with either "ytsearch:" or "scsearch:". This could possibly be expanded with plugins. + # EMPTY - no results for the query (result.tracks will be empty) + # ERROR - the track encountered an exception during loading + if results.load_type == LoadType.EMPTY: + return await ctx.send("I couldn'\t find any tracks for that query.") + elif results.load_type == LoadType.PLAYLIST: tracks = results.tracks + # Add all of the tracks from the playlist to the queue. for track in tracks: - # Add all of the tracks from the playlist to the queue. - player.add(requester=ctx.author.id, track=track) + # requester isn't necessary but it helps keep track of who queued what. + # You can store additional metadata by passing it as a kwarg (i.e. key=value) + player.add(track=track, requester=ctx.author.id) embed.title = 'Playlist Enqueued!' embed.description = f'{results.playlist_info.name} - {len(tracks)} tracks' @@ -210,7 +265,9 @@ async def play(self, ctx, *, query: str): embed.title = 'Track Enqueued' embed.description = f'[{track.title}]({track.uri})' - player.add(requester=ctx.author.id, track=track) + # requester isn't necessary but it helps keep track of who queued what. + # You can store additional metadata by passing it as a kwarg (i.e. key=value) + player.add(track=track, requester=ctx.author.id) await ctx.send(embed=embed) @@ -220,6 +277,7 @@ async def play(self, ctx, *, query: str): await player.play() @commands.command(aliases=['lp']) + @commands.check(create_player) async def lowpass(self, ctx, strength: float): """ Sets the strength of the low pass filter. """ # Get the player for this guild from cache. @@ -253,18 +311,12 @@ async def lowpass(self, ctx, strength: float): await ctx.send(embed=embed) @commands.command(aliases=['dc']) + @commands.check(create_player) async def disconnect(self, ctx): """ Disconnects the player from the voice channel and clears its queue. """ player = self.bot.lavalink.player_manager.get(ctx.guild.id) - - if not ctx.voice_client: - # We can't disconnect, if we're not connected. - return await ctx.send('Not connected.') - - if not ctx.author.voice or (player.is_connected and ctx.author.voice.channel.id != int(player.channel_id)): - # Abuse prevention. Users not in voice channels, or not in the same voice channel as the bot - # may not disconnect the bot. - return await ctx.send('You\'re not in my voicechannel!') + # The necessary voice channel checks are handled in "create_player." + # We don't need to duplicate code checking them again. # Clear the queue to ensure old tracks don't start playing # when someone else queues something. @@ -273,7 +325,7 @@ async def disconnect(self, ctx): await player.stop() # Disconnect from the voice channel. await ctx.voice_client.disconnect(force=True) - await ctx.send('*⃣ | Disconnected.') + await ctx.send('✳ | Disconnected.') def setup(bot): diff --git a/lavalink/__init__.py b/lavalink/__init__.py index 3f0e7a33..6e6e303b 100644 --- a/lavalink/__init__.py +++ b/lavalink/__init__.py @@ -4,57 +4,32 @@ __author__ = 'Devoxin' __license__ = 'MIT' __copyright__ = 'Copyright 2017-present Devoxin' -__version__ = '4.0.7' +__version__ = '5.0.0' -import inspect import logging import sys +from .abc import BasePlayer, DeferredAudioTrack, Source from .client import Client -from .errors import AuthenticationError, InvalidTrack, LoadError, NodeError -from .events import (Event, NodeChangedEvent, NodeConnectedEvent, - NodeDisconnectedEvent, PlayerUpdateEvent, QueueEndEvent, - TrackEndEvent, TrackExceptionEvent, TrackLoadFailedEvent, +from .errors import (AuthenticationError, ClientError, InvalidTrack, LoadError, + PlayerErrorEvent, RequestError) +from .events import (Event, IncomingWebSocketMessage, NodeChangedEvent, + NodeConnectedEvent, NodeDisconnectedEvent, NodeReadyEvent, + PlayerUpdateEvent, QueueEndEvent, TrackEndEvent, + TrackExceptionEvent, TrackLoadFailedEvent, TrackStartEvent, TrackStuckEvent, WebSocketClosedEvent) -from .filters import (ChannelMix, Equalizer, Filter, Karaoke, LowPass, - Rotation, Timescale, Tremolo, Vibrato, Volume) -from .models import (AudioTrack, BasePlayer, DefaultPlayer, DeferredAudioTrack, - LoadResult, LoadType, PlaylistInfo, Plugin, Source) +from .filters import (ChannelMix, Distortion, Equalizer, Filter, Karaoke, + LowPass, Rotation, Timescale, Tremolo, Vibrato, Volume) from .node import Node from .nodemanager import NodeManager +from .player import DefaultPlayer from .playermanager import PlayerManager +from .server import (AudioTrack, EndReason, LoadResult, LoadResultError, + LoadType, PlaylistInfo, Plugin, Severity) from .stats import Penalty, Stats from .utils import (decode_track, encode_track, format_time, parse_time, timestamp_to_millis) -from .websocket import WebSocket - - -def enable_debug_logging(submodule: str = None): - """ - Sets up a logger to stdout. This solely exists to make things easier for - end-users who want to debug issues with Lavalink.py. - - Parameters - ---------- - module: :class:`str` - The module to enable logging for. ``None`` to enable debug logging for - the entirety of Lavalink.py. - - Example: ``lavalink.enable_debug_logging('websocket')`` - """ - module_name = 'lavalink.{}'.format(submodule) if submodule else 'lavalink' - log = logging.getLogger(module_name) - - fmt = logging.Formatter( - '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s', # lavalink.py - datefmt="%H:%M:%S" - ) - - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(fmt) - log.addHandler(handler) - log.setLevel(logging.DEBUG) def listener(*events: Event): @@ -90,36 +65,3 @@ def wrapper(func): setattr(func, '_lavalink_events', events) return func return wrapper - - -def add_event_hook(*hooks, event: Event = None): - """ - Adds an event hook to be dispatched on an event. - - Note - ---- - Track event dispatch order is not guaranteed! - For example, this means you could receive a :class:`TrackStartEvent` before you receive a - :class:`TrackEndEvent` when executing operations such as ``skip()``. - - Parameters - ---------- - hooks: :class:`function` - The hooks to register for the given event type. - If ``event`` parameter is left empty, then it will run when any event is dispatched. - event: :class:`Event` - The event the hook belongs to. This will dispatch when that specific event is - dispatched. Defaults to ``None`` which means the hook is dispatched on all events. - """ - if event is not None and Event not in event.__bases__: - raise TypeError('Event parameter is not of type Event or None') - - event_name = event.__name__ if event is not None else 'Generic' - event_hooks = Client._event_hooks[event_name] - - for hook in hooks: - if not callable(hook) or not inspect.iscoroutinefunction(hook): - raise TypeError('Hook is not callable or a coroutine') - - if hook not in event_hooks: - event_hooks.append(hook) diff --git a/lavalink/__main__.py b/lavalink/__main__.py index 179f0f2d..fe19970a 100644 --- a/lavalink/__main__.py +++ b/lavalink/__main__.py @@ -1,28 +1,149 @@ import os import re import sys +import traceback from subprocess import PIPE, Popen +from time import time +from typing import List, Optional import requests -LAVALINK_BASE_URL = 'https://ci.fredboat.com/repository/download/Lavalink_Build/.lastSuccessful/Lavalink.jar?guest=1&branch=refs/heads/{}' -APPLICATION_BASE_URL = 'https://raw.githubusercontent.com/freyacodes/Lavalink/{}/LavalinkServer/application.yml.example' +RELEASES_URL = 'https://api.github.com/repos/lavalink-devs/Lavalink/releases' +APPLICATION_BASE_URL = 'https://raw.githubusercontent.com/lavalink-devs/Lavalink/{}/LavalinkServer/application.yml.example' +SEMVER_REGEX = re.compile(r'(\d+)\.(\d+)(?:\.(\d+))?(?:-(\w+)(?:\.(\d+)))?') + + +class Release: + def __init__(self, release_json): + self.tag: str = release_json['tag_name'] + self.major_version: int = int(self.tag[0]) + self.prerelease: bool = release_json['prerelease'] + self.draft: bool = release_json['draft'] + + assets = release_json['assets'] + jars = [asset['browser_download_url'] for asset in assets if asset['name'].endswith('.jar')] + self.download_url: Optional[str] = jars[0] if jars else None + + def __str__(self) -> str: + return f'{self.tag} {"[prerelease]" if self.prerelease else ""}' + + def __eq__(self, other): + if not isinstance(other, Release): + return False + + this_match = SEMVER_REGEX.match(self.tag) + other_match = SEMVER_REGEX.match(other.tag) + + if not this_match or not other_match: + raise ValueError('Cannot compare version strings as they do not match the regex pattern') + + this_major, this_minor, this_patch, _, this_build = this_match.groups() + this_version = (int(this_major), int(this_minor), int(this_patch or 0), int(this_build or 0)) + + other_major, other_minor, other_patch, _, other_build = other_match.groups() + other_version = (int(other_major), int(other_minor), int(other_patch or 0), int(other_build or 0)) + + return this_version == other_version + + def __lt__(self, other): + this_match = SEMVER_REGEX.match(self.tag) + + if not this_match: + raise ValueError('Cannot compare version strings as they do not match the regex pattern') + + this_major, this_minor, this_patch, _, this_build = this_match.groups() + this_version = (int(this_major), int(this_minor), int(this_patch or 0), int(this_build or 0)) + + if isinstance(other, str): + parts = list(map(int, other.split('.'))) + + if len(parts) == 1: + other_version = (parts[0], 0, 0, 0) + elif len(parts) == 2: + other_version = (parts[0], parts[1], 0, 0) + elif len(parts) == 3: + other_version = (parts[0], parts[1], parts[2], 0) + else: + raise ValueError('Cannot compare version string with more fields than major.minor.patch') + elif isinstance(other, Release): + other_match = SEMVER_REGEX.match(other.tag) + + if not this_match or not other_match: + raise ValueError('Cannot compare version strings as they do not match the regex pattern') + + other_major, other_minor, other_patch, _, other_build = other_match.groups() + other_version = (int(other_major), int(other_minor), int(other_patch or 0), int(other_build or 0)) + else: + raise TypeError(f'"<" not supported between instances of "{type(self).__name__}" and "{type(other).__name__}"') + + return this_version < other_version + + def __gt__(self, other): + this_match = SEMVER_REGEX.match(self.tag) + + if not this_match: + raise ValueError('Cannot compare version strings as they do not match the regex pattern') + + this_major, this_minor, this_patch, _, this_build = this_match.groups() + this_version = (int(this_major), int(this_minor), int(this_patch or 0), int(this_build or 0)) + + if isinstance(other, str): + parts = list(map(int, other.split('.'))) + + if len(parts) == 1: + other_version = (parts[0], 0, 0, 0) + elif len(parts) == 2: + other_version = (parts[0], parts[1], 0, 0) + elif len(parts) == 3: + other_version = (parts[0], parts[1], parts[2], 0) + else: + raise ValueError('Cannot compare version string with more fields than major.minor.patch') + elif isinstance(other, Release): + other_match = SEMVER_REGEX.match(other.tag) + + if not this_match or not other_match: + raise ValueError('Cannot compare version strings as they do not match the regex pattern') + + other_major, other_minor, other_patch, _, other_build = other_match.groups() + other_version = (int(other_major), int(other_minor), int(other_patch or 0), int(other_build or 0)) + else: + raise TypeError(f'">" not supported between instances of "{type(self).__name__}" and "{type(other).__name__}"') + + return this_version > other_version + + def __ge__(self, other): + return self.__eq__(other) or self.__gt__(other) + + def __le__(self, other): + return self.__eq__(other) or self.__lt__(other) def display_help(): print(""" -download - Downloads the latest (stable) Lavalink jar. - --fetch-dev Fetches the latest Lavalink development jar. +download - Find and download specific Lavalink server versions. --no-overwrite Renames an existing lavalink.jar to lavalink.old.jar config - Downloads a fresh application.yml. --fetch-dev Fetches the latest application.yml from the development branch. --no-overwrite Renames an existing application.yml to application.old.yml. info - Extracts version and build information from an existing Lavalink.jar. - """.strip()) + """.strip(), file=sys.stdout) + + +def format_bytes(length: int) -> str: + sizes = ['B', 'KB', 'MB', 'GB', 'TB'] + unit = 0 + + while length >= 1024 and unit < len(sizes) - 1: + unit += 1 + length /= 1024 + + return f'{length:.2f} {sizes[unit]}' def download(dl_url, path): - res = requests.get(dl_url, stream=True) + res = requests.get(dl_url, stream=True, timeout=15) + + download_begin = round(time() * 1000) def report_progress(cur, tot): bar_len = 32 @@ -30,93 +151,233 @@ def report_progress(cur, tot): filled_len = int(round(bar_len * progress)) percent = round(progress * 100, 2) + elapsed = round(time() * 1000) - download_begin + if elapsed > 0: + correction = 1000 / elapsed + speed = cur * correction + else: + speed = 0 # placeholder until we have enough data to calculate + progress_bar = '█' * filled_len + ' ' * (bar_len - filled_len) - sys.stdout.write('Downloading |%s| %0.2f%% (%d/%d)\r' % (progress_bar, percent, cur, tot)) + sys.stdout.write(f'Downloading |{progress_bar}| {percent:.1f}% ({cur}/{tot}, {format_bytes(speed)}/s)\r') sys.stdout.flush() if cur >= tot: sys.stdout.write('\n') - def read_chunk(f, chunk_size=8192): + def read_chunk(out, chunk_size=8192): total_bytes = int(res.headers['Content-Length'].strip()) current_bytes = 0 for chunk in res.iter_content(chunk_size): - f.write(chunk) + out.write(chunk) current_bytes += len(chunk) report_progress(min(current_bytes, total_bytes), total_bytes) - with open(path, 'wb') as f: - read_chunk(f) + with open(path, 'wb') as out: + read_chunk(out) -def main(): # pylint: disable=too-many-locals,too-many-statements - if len(sys.argv) < 2 or sys.argv[1] == '--help' or sys.argv[1] == 'help' or sys.argv[1] == '?': - display_help() - return +def select_release_unattended(non_draft: List[Release], version_selector: str) -> Release: + matcher = SEMVER_REGEX.match(version_selector) - cwd = os.getcwd() - _, action, *arguments = sys.argv + if matcher: + def exact_version(release: Release): + return release.tag == version_selector - if action == 'download': - target_branch = 'dev' if '--fetch-dev' in arguments else 'master' - dl_url = LAVALINK_BASE_URL.format(target_branch) - dl_path = os.path.join(cwd, 'lavalink.jar') + predicate = exact_version + elif version_selector.startswith('>='): + def gte(release: Release): + return release >= version_selector[2:] - if '--no-overwrite' in arguments and os.path.exists(dl_path): - os.rename(dl_path, os.path.join(cwd, 'lavalink.old.jar')) + predicate = gte + elif version_selector.startswith('<='): + def lte(release: Release): + return release.tag <= version_selector[2:] - download(dl_url, dl_path) - print('Downloaded to {}'.format(dl_path)) - sys.exit(0) - elif action == 'config': - target_branch = 'dev' if '--fetch-dev' in arguments else 'master' - dl_url = APPLICATION_BASE_URL.format(target_branch) - dl_path = os.path.join(cwd, 'application.yml') + predicate = lte + elif version_selector.startswith('>'): + def gt(release: Release): # pylint: disable=C0103 + return release > version_selector[1:] - if '--no-overwrite' in arguments and os.path.exists(dl_path): - os.rename(dl_path, os.path.join(cwd, 'application.old.yml')) + predicate = gt + elif version_selector.startswith('<'): + def lte(release: Release): + return release.tag < version_selector[1:] - download(dl_url, dl_path) - print('Downloaded to {}'.format(dl_path)) - sys.exit(0) - elif action == 'info': - check_names = ['lavalink.jar', 'Lavalink.jar', 'LAVALINK.JAR'] + predicate = lte + elif version_selector.startswith('~='): + minimum = version_selector[2:] + major, minor, _ = minimum.split('.') + maximum = f'{major}.{int(minor) + 1}.0' + + def compatible(release: Release): + return minimum <= release < maximum + + predicate = compatible + else: + # TODO: Support multiple version specifiers (e.g. >=3.7.0,<4.0.0) + raise ValueError('Unsupported version selector') + + selected_release = next((release for release in non_draft if predicate(release)), None) - if arguments: - check_names.extend([arguments[0]]) + if not selected_release: + print('Couldn\'t find a suitable release with the provided version selector.', file=sys.stderr) + sys.exit(1) - file_name = next((fn for fn in check_names if os.path.exists(fn)), None) + print(f'Release selected: {selected_release.tag}', file=sys.stdout) - if not file_name: - print('Unable to display Lavalink server info: No Lavalink file found.') + return selected_release + + +def select_release(non_draft: List[Release]) -> Release: + suitable_releases = [] + + for release in non_draft: + if not release.download_url: + continue + + newest: Optional[Release] = next((sr for sr in suitable_releases if sr.major_version == release.major_version), None) + + if newest: + if newest > release: # GitHub gives newest->oldest releases, so it could be that we iterate over a pre-release before a release. + if newest.prerelease and not release.prerelease: # If that is the case, we check the version against the current non-prerelease + current_non_prerelease: Optional[Release] = next((sr for sr in suitable_releases if sr.major_version == release.major_version + and not sr.prerelease), None) + + if current_non_prerelease and current_non_prerelease > release: + continue + else: + continue + + suitable_releases.append(release) + + if not suitable_releases: + print('No suitable Lavalink releases were found.', file=sys.stdout) + sys.exit(0) # Perhaps this should be an error, however this could also be valid (but very unlikely). + + if len(suitable_releases) > 1: + print('There are multiple Lavalink versions to choose from.\n' + 'They have automatically been filtered based on their version, and whether they are a pre-release.\n\n' + 'Type the number of the release you would like to download.\n', file=sys.stdout) + + for index, release in enumerate(suitable_releases, start=1): + print(f'[{index}] {release}', file=sys.stdout) + + try: + selected = int(input('> ')) - 1 + + if not 0 <= selected <= len(suitable_releases): + raise ValueError + except ValueError: + print('An incorrect selection has been made, cancelling...', file=sys.stderr) sys.exit(1) + else: + selected = 0 - proc = Popen(['java', '-jar', file_name, '--version'], stdout=PIPE, stderr=PIPE, text=True) - stdout, stderr = proc.communicate() + return suitable_releases[selected] - if stderr: - if 'UnsupportedClassVersionError' in stderr: - java_proc = Popen(['java', '-version'], stdout=PIPE, stderr=PIPE, text=True) - j_stdout, j_stderr = java_proc.communicate() - j_ver = re.search(r'java version "([\d._]*)"', j_stdout or j_stderr) - java_version = j_ver.group(1) if j_ver else 'UNKNOWN' - if java_version.startswith('1.8'): - java_version = f'8/{java_version}' +def download_jar(arguments: List[str]): + try: + res = requests.get(RELEASES_URL, timeout=15).json() + except requests.exceptions.JSONDecodeError: + print('Failed to retrieve Lavalink releases', file=sys.stderr) + sys.exit(1) - print('Unable to display Lavalink server info.\nYour Java version is out of date. (Java {})\n\n' - 'Java 11+ is required to run Lavalink.'.format(java_version)) - sys.exit(1) + releases = list(map(Release, res)) + non_draft = [r for r in releases if not r.draft] - print(stderr) + if arguments: + try: + release = select_release_unattended(non_draft, arguments[0]) + except ValueError: + traceback.print_exc(file=sys.stderr) sys.exit(1) - else: - print(stdout.strip()) - sys.exit(0) else: - print('Invalid argument \'{}\'. Use --help to show usage.'.format(action)) + release = select_release(non_draft) + + cwd = os.getcwd() + dl_url = release.download_url + dl_path = os.path.join(cwd, 'lavalink.jar') + + if '--no-overwrite' in arguments and os.path.exists(dl_path): + os.rename(dl_path, os.path.join(cwd, 'lavalink.old.jar')) + + download(dl_url, dl_path) + print(f'Downloaded {release.tag} to {dl_path}', file=sys.stdout) + sys.exit(0) + + +def download_config(arguments: List[str], branch: str): + cwd = os.getcwd() + dl_url = APPLICATION_BASE_URL.format(branch) + dl_path = os.path.join(cwd, 'application.yml') + + if '--no-overwrite' in arguments and os.path.exists(dl_path): + os.rename(dl_path, os.path.join(cwd, 'application.old.yml')) + + download(dl_url, dl_path) + print(f'Downloaded to {dl_path}', file=sys.stdout) + sys.exit(0) + + +def print_info(arguments: List[str]): + check_names = ['lavalink.jar', 'Lavalink.jar', 'LAVALINK.JAR'] + + if arguments: + check_names.extend([arguments[0]]) + + file_name = next((fn for fn in check_names if os.path.exists(fn)), None) + + if not file_name: + print('Unable to display Lavalink server info: No Lavalink file found.', file=sys.stderr) + sys.exit(1) + + proc = Popen(['java', '-jar', file_name, '--version'], stdout=PIPE, stderr=PIPE, text=True) + stdout, stderr = proc.communicate() + + if stderr: + if 'UnsupportedClassVersionError' in stderr: + java_proc = Popen(['java', '-version'], stdout=PIPE, stderr=PIPE, text=True) + j_stdout, j_stderr = java_proc.communicate() + j_ver = re.search(r'java version "([\d._]*)"', j_stdout or j_stderr) + java_version = j_ver.group(1) if j_ver else 'UNKNOWN' + + if java_version.startswith('1.8'): + java_version = f'8/{java_version}' + + print(f'Unable to display Lavalink server info.\nYour Java version is out of date. (Java {java_version})\n\n' + 'Java 11+ is required to run Lavalink.', file=sys.stderr) + sys.exit(1) + + print(stderr, file=sys.stderr) sys.exit(1) + else: + print(stdout.strip(), file=sys.stdout) + sys.exit(0) + + +def main(): + if len(sys.argv) < 2 or sys.argv[1] in ('--help', 'help', '?', '/help'): + display_help() + return + + _, action, *arguments = sys.argv + target_branch = 'dev' if '--fetch-dev' in arguments else 'master' + + try: + if action == 'download': + download_jar(arguments) + elif action == 'config': + download_config(arguments, target_branch) + elif action == 'info': + print_info(arguments) + else: + print(f'Invalid argument \'{action}\'. Use --help to show usage.', file=sys.stderr) + sys.exit(1) + except KeyboardInterrupt: + sys.exit(2) # CTRL-C = SIGINT = 2 if __name__ == '__main__': diff --git a/lavalink/abc.py b/lavalink/abc.py new file mode 100644 index 00000000..a83e804c --- /dev/null +++ b/lavalink/abc.py @@ -0,0 +1,320 @@ +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from .common import MISSING +from .errors import InvalidTrack, LoadError +from .events import TrackLoadFailedEvent +from .server import AudioTrack + +if TYPE_CHECKING: + from .client import Client + from .node import Node + from .player import LoadResult + +_log = logging.getLogger(__name__) + + +class BasePlayer(ABC): + """ + Represents the BasePlayer all players must be inherited from. + + Attributes + ---------- + client: :class:`Client` + The Lavalink client instance. + guild_id: :class:`int` + The guild id of the player. + node: :class:`Node` + The node that the player is connected to. + channel_id: Optional[:class:`int`] + The ID of the voice channel the player is connected to. + This could be None if the player isn't connected. + current: Optional[:class:`AudioTrack`] + The currently playing track. + """ + def __init__(self, guild_id: int, node: 'Node'): + self.client: 'Client' = node.manager.client + self.guild_id: int = guild_id + self.node: 'Node' = node + self.channel_id: Optional[int] = None + self.current: Optional[AudioTrack] = None + + self._next: Optional[AudioTrack] = None + self._internal_id: str = str(guild_id) + self._original_node: Optional['Node'] = None # This is used internally for failover. + self._voice_state = {} + + @abstractmethod + async def _handle_event(self, event): + raise NotImplementedError + + @abstractmethod + async def _update_state(self, state: dict): + raise NotImplementedError + + async def play_track(self, + track: Union[AudioTrack, 'DeferredAudioTrack'], + start_time: int = MISSING, + end_time: int = MISSING, + no_replace: bool = MISSING, + volume: int = MISSING, + pause: bool = MISSING, + **kwargs): + """|coro| + + Plays the given track. + + Parameters + ---------- + track: Union[:class:`AudioTrack`, :class:`DeferredAudioTrack`] + The track to play. + start_time: :class:`int` + The number of milliseconds to offset the track by. + If left unspecified or ``None`` is provided, the track will start from the beginning. + end_time: :class:`int` + The position at which the track should stop playing. + This is an absolute position, so if you want the track to stop at 1 minute, you would pass 60000. + The default behaviour is to play until no more data is received from the remote server. + If left unspecified or ``None`` is provided, the default behaviour is exhibited. + no_replace: :class:`bool` + If set to true, operation will be ignored if a track is already playing or paused. + The default behaviour is to always replace. + If left unspecified or None is provided, the default behaviour is exhibited. + volume: :class:`int` + The initial volume to set. This is useful for changing the volume between tracks etc. + If left unspecified or ``None`` is provided, the volume will remain at its current setting. + pause: :class:`bool` + Whether to immediately pause the track after loading it. + The default behaviour is to never pause. + If left unspecified or ``None`` is provided, the default behaviour is exhibited. + **kwargs: Any + The kwargs to use when playing. You can specify any extra parameters that may be + used by plugins, which offer extra features not supported out-of-the-box by Lavalink.py. + """ + if track is MISSING or not isinstance(track, AudioTrack): + raise ValueError('track must be an instance of an AudioTrack!') + + options = kwargs + + if start_time is not MISSING: + if not isinstance(start_time, int) or 0 > start_time: + raise ValueError('start_time must be an int with a value equal to, or greater than 0') + + options['position'] = start_time + + if end_time is not MISSING: + if not isinstance(end_time, int) or 1 > end_time: + raise ValueError('end_time must be an int with a value equal to, or greater than 1') + + options['end_time'] = end_time + + if no_replace is not MISSING: + if not isinstance(no_replace, bool): + raise TypeError('no_replace must be a bool') + + options['no_replace'] = no_replace + + if volume is not MISSING: + if not isinstance(volume, int): + raise TypeError('volume must be an int') + + self.volume = max(min(volume, 1000), 0) + options['volume'] = self.volume + + if pause is not MISSING: + if not isinstance(pause, bool): + raise TypeError('pause must be a bool') + + options['paused'] = pause + + playable_track = track.track + + if playable_track is None: + if not isinstance(track, DeferredAudioTrack): + raise InvalidTrack('Cannot play the AudioTrack as \'track\' is None, and it is not a DeferredAudioTrack!') + + try: + playable_track = await track.load(self.client) + except LoadError as load_error: + await self.client._dispatch_event(TrackLoadFailedEvent(self, track, load_error)) + + if playable_track is None: # This should only fire when a DeferredAudioTrack fails to yield a base64 track string. + await self.client._dispatch_event(TrackLoadFailedEvent(self, track, None)) + return + + self._next = track + await self.node.update_player(self._internal_id, encoded_track=playable_track, **options) + + def cleanup(self): + pass + + async def destroy(self): + """|coro| + + Destroys the current player instance. + + Shortcut for :func:`PlayerManager.destroy`. + """ + await self.client.player_manager.destroy(self.guild_id) + + async def _voice_server_update(self, data): + self._voice_state.update(endpoint=data['endpoint'], token=data['token']) + + if 'sessionId' not in self._voice_state: # We should've received session_id from a VOICE_STATE_UPDATE before receiving a VOICE_SERVER_UPDATE. + _log.warning('[Player:%s] Missing sessionId, is the client User ID correct?', self.guild_id) + + await self._dispatch_voice_update() + + async def _voice_state_update(self, data): + raw_channel_id = data['channel_id'] + self.channel_id = int(raw_channel_id) if raw_channel_id else None + + if not self.channel_id: # We're disconnecting + self._voice_state.clear() + return + + if data['session_id'] != self._voice_state.get('sessionId'): + self._voice_state.update(sessionId=data['session_id']) + + await self._dispatch_voice_update() + + async def _dispatch_voice_update(self): + if {'sessionId', 'endpoint', 'token'} == self._voice_state.keys(): + await self.node.update_player(self._internal_id, voice_state=self._voice_state) + + @abstractmethod + async def node_unavailable(self): + """|coro| + + Called when a player's node becomes unavailable. + Useful for changing player state before it's moved to another node. + """ + raise NotImplementedError + + @abstractmethod + async def change_node(self, node: 'Node'): + """|coro| + + Called when a node change is requested for the current player instance. + + Parameters + ---------- + node: :class:`Node` + The new node to switch to. + """ + raise NotImplementedError + + +class DeferredAudioTrack(ABC, AudioTrack): + """ + Similar to an :class:`AudioTrack`, however this track only stores metadata up until it's + played, at which time :func:`load` is called to retrieve a base64 string which is then used for playing. + + Note + ---- + For implementation: The ``track`` field need not be populated as this is done later via + the :func:`load` method. You can optionally set ``self.track`` to the result of :func:`load` + during implementation, as a means of caching the base64 string to avoid fetching it again later. + This should serve the purpose of speeding up subsequent play calls in the event of repeat being enabled, + for example. + """ + @abstractmethod + async def load(self, client: 'Client'): + """|coro| + + Retrieves a base64 string that's playable by Lavalink. + For example, you can use this method to search Lavalink for an identical track from other sources, + which you can then use the base64 string of to play the track on Lavalink. + + Parameters + ---------- + client: :class:`Client` + This will be an instance of the Lavalink client 'linked' to this track. + + Returns + ------- + :class:`str` + A Lavalink-compatible base64-encoded string containing track metadata. + """ + raise NotImplementedError + + +class Source(ABC): + def __init__(self, name: str): + self.name: str = name + + def __eq__(self, other): + if self.__class__ is other.__class__: + return self.name == other.name + + return False + + def __hash__(self): + return hash(self.name) + + @abstractmethod + async def load_item(self, client: 'Client', query: str) -> Optional['LoadResult']: + """|coro| + + Loads a track with the given query. + + Parameters + ---------- + client: :class:`Client` + The Lavalink client. This could be useful for performing a Lavalink search + for an identical track from other sources, if needed. + query: :class:`str` + The search query that was provided. + + Returns + ------- + Optional[:class:`LoadResult`] + A LoadResult, or None if there were no matches for the provided query. + """ + raise NotImplementedError + + def __repr__(self): + return f'' + + +class Filter: + """ + A class representing a Lavalink audio filter. + + Parameters + ---------- + values: Union[Dict[str, Any], List[Union[float, int]], float] + The values for this filter. + plugin_filter: :class:`bool` + Whether this filter is part of a Lavalink plugin. Typically, this will be ``True`` + when creating your own filters. + + Attributes + ---------- + values: Union[Dict[str, Any], List[Union[float, int]], float] + The values for this filter. + plugin_filter: :class:`bool` + Whether this filter is part of a Lavalink plugin. + """ + def __init__(self, values: Union[Dict[str, Any], List[Union[float, int]], float], plugin_filter: bool = False): + self.values = values + self.plugin_filter: bool = plugin_filter + + @abstractmethod + def update(self, **kwargs): + """ Updates the internal values to match those provided. """ + raise NotImplementedError + + @abstractmethod + def serialize(self) -> Dict[str, Any]: + """ + Transforms the internal values into a dict matching the structure Lavalink expects. + + Example: + + .. code:: python + + return {"yourCustomFilter": {"gain": 5}} + """ + raise NotImplementedError diff --git a/lavalink/client.py b/lavalink/client.py index 2d711907..1d3327fb 100644 --- a/lavalink/client.py +++ b/lavalink/client.py @@ -22,24 +22,29 @@ SOFTWARE. """ import asyncio +import inspect import itertools import logging import random from collections import defaultdict from inspect import getmembers, ismethod -from typing import Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import aiohttp -from .errors import AuthenticationError, NodeError +from .abc import BasePlayer, Source from .events import Event -from .models import DefaultPlayer, LoadResult, Source from .node import Node from .nodemanager import NodeManager +from .player import DefaultPlayer from .playermanager import PlayerManager +from .server import AudioTrack, LoadResult _log = logging.getLogger(__name__) +PlayerT = TypeVar('PlayerT', bound=BasePlayer) +EventT = TypeVar('EventT', bound=Event) + class Client: """ @@ -49,10 +54,10 @@ class Client: ---------- user_id: Union[:class:`int`, :class:`str`] The user id of the bot. - player: Optional[:class:`BasePlayer`] - The class that should be used for the player. Defaults to ``DefaultPlayer``. + player: Type[:class:`BasePlayer`] + The class that should be used for the player. Defaults to :class:`DefaultPlayer`. Do not change this unless you know what you are doing! - regions: Optional[:class:`dict`] + regions: Optional[Dict[str, Tuple[str]]] A mapping of continent -> Discord RTC regions. The key should be an identifier used when instantiating an node. The values should be a list of RTC regions that will be handled by the associated identifying key. @@ -76,31 +81,56 @@ class Client: Attributes ---------- node_manager: :class:`NodeManager` - Represents the node manager that contains all lavalink nodes. + The node manager, used for storing and managing all registered Lavalink nodes. player_manager: :class:`PlayerManager` - Represents the player manager that contains all the players. + The player manager, used for storing and managing all players. + sources: Set[:class:`Source`] + The custom sources registered to this client. """ - _event_hooks = defaultdict(list) + __slots__ = ('_session', '_user_id', '_event_hooks', 'node_manager', 'player_manager', 'sources') - def __init__(self, user_id: Union[int, str], player=DefaultPlayer, regions: dict = None, - connect_back: bool = False): + def __init__(self, user_id: Union[int, str], player: Type[PlayerT] = DefaultPlayer, + regions: Optional[Dict[str, Tuple[str]]] = None, connect_back: bool = False): if not isinstance(user_id, (str, int)) or isinstance(user_id, bool): # bool has special handling because it subclasses `int`, so will return True for the first isinstance check. - raise TypeError('user_id must be either an int or str (not {}). If the type is None, ' + raise TypeError(f'user_id must be either an int or str (not {type(user_id).__name__}). If the type is None, ' 'ensure your bot has fired "on_ready" before instantiating ' - 'the Lavalink client. Alternatively, you can hardcode your user ID.' - .format(user_id)) + 'the Lavalink client. Alternatively, you can hardcode your user ID.') - self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) - self._user_id: str = str(user_id) - self._connect_back: bool = connect_back - self.node_manager: NodeManager = NodeManager(self, regions) + self._session: aiohttp.ClientSession = aiohttp.ClientSession() + self._user_id: str = int(user_id) + self._event_hooks = defaultdict(list) + self.node_manager: NodeManager = NodeManager(self, regions, connect_back) self.player_manager: PlayerManager = PlayerManager(self, player) self.sources: Set[Source] = set() - def add_event_hook(self, hook): + @property + def nodes(self) -> List[Node]: + """ + Convenience shortcut for :attr:`NodeManager.nodes`. + """ + return self.node_manager.nodes + + @property + def players(self) -> Dict[int, BasePlayer]: + """ + Convenience shortcut for :attr:`PlayerManager.players`. + """ + return self.player_manager.players + + async def close(self): + """|coro| + + Closes all active connections and frees any resources in use. + """ + for node in self.node_manager: + await node.destroy() + + await self._session.close() + + def add_event_hook(self, *hooks, event: Optional[Type[EventT]] = None): """ - Registers a function to recieve and process Lavalink events. + Adds one or more event hooks to be dispatched on an event. Note ---- @@ -110,13 +140,27 @@ def add_event_hook(self, hook): Parameters ---------- - hook: :class:`function` - The function to register. + hooks: :class:`function` + The hooks to register for the given event type. + If ``event`` parameter is left empty, then it will run when any event is dispatched. + event: Optional[Type[:class:`Event`]] + The event the hooks belong to. They will be called when that specific event type is + dispatched. Defaults to ``None`` which means the hook is dispatched on all events. """ - if hook not in self._event_hooks['Generic']: - self._event_hooks['Generic'].append(hook) + if event is not None and Event not in event.__bases__: + raise TypeError('Event parameter is not of type Event or None') + + event_name = event.__name__ if event is not None else 'Generic' + event_hooks = self._event_hooks[event_name] + + for hook in hooks: + if not callable(hook) or not inspect.iscoroutinefunction(hook): + raise TypeError('Hook is not callable or a coroutine') - def add_event_hooks(self, cls): + if hook not in event_hooks: + event_hooks.append(hook) + + def add_event_hooks(self, cls: Any): # TODO: I don't think Any is the correct type here... """ Scans the provided class ``cls`` for functions decorated with :func:`listener`, and sets them up to process Lavalink events. @@ -137,8 +181,8 @@ def add_event_hooks(self, cls): Parameters ---------- - cls: :class:`Class` - An instance of a class. + cls: Any + An instance of a class containing event hook methods. """ methods = getmembers(cls, predicate=lambda meth: hasattr(meth, '__name__') and not meth.__name__.startswith('_') and ismethod(meth) @@ -164,7 +208,7 @@ def register_source(self, source: Source): The source to register. """ if not isinstance(source, Source): - raise TypeError('source must inherit from Source!') + raise TypeError(f'Class \'{type(source).__name__}\' must inherit Source!') self.sources.add(source) @@ -185,10 +229,11 @@ def get_source(self, source_name: str) -> Optional[Source]: """ return next((source for source in self.sources if source.name == source_name), None) - def add_node(self, host: str, port: int, password: str, region: str, - resume_key: str = None, resume_timeout: int = 60, name: str = None, - reconnect_attempts: int = 3, filters: bool = True, ssl: bool = False): + def add_node(self, host: str, port: int, password: str, region: str, name: str = None, + ssl: bool = False, session_id: Optional[str] = None) -> Node: """ + Shortcut for :meth:`NodeManager.add_node`. + Adds a node to Lavalink's node manager. Parameters @@ -201,29 +246,47 @@ def add_node(self, host: str, port: int, password: str, region: str, The password used for authentication. region: :class:`str` The region to assign this node to. - resume_key: Optional[:class:`str`] - A resume key used for resuming a session upon re-establishing a WebSocket connection to Lavalink. - Defaults to ``None``. - resume_timeout: Optional[:class:`int`] - How long the node should wait for a connection while disconnected before clearing all players. - Defaults to ``60``. name: Optional[:class:`str`] An identifier for the node that will show in logs. Defaults to ``None``. - reconnect_attempts: Optional[:class:`int`] - The amount of times connection with the node will be reattempted before giving up. - Set to `-1` for infinite. Defaults to ``3``. - filters: Optional[:class:`bool`] - Whether to use the new ``filters`` op instead of the ``equalizer`` op. - If you're running a build without filter support, set this to ``False``. ssl: Optional[:class:`bool`] Whether to use SSL for the node. SSL will use ``wss`` and ``https``, instead of ``ws`` and ``http``, respectively. Your node should support SSL if you intend to enable this, either via reverse proxy or other methods. Only enable this if you know what you're doing. + session_id: Optional[:class:`str`] + The ID of the session to resume. Defaults to ``None``. + Only specify this if you have the ID of the session you want to resume. + + Returns + ------- + :class:`Node` + The created Node instance. + """ + return self.node_manager.add_node(host, port, password, region, name, ssl, session_id) + + async def get_local_tracks(self, query: str) -> LoadResult: + """|coro| + + Searches :attr:`sources` registered to this client for the given query. + + Parameters + ---------- + query: :class:`str` + The query to perform a search for. + + Returns + ------- + :class:`LoadResult` """ - self.node_manager.add_node(host, port, password, region, resume_key, resume_timeout, name, reconnect_attempts, - filters, ssl) + for source in self.sources: + load_result = await source.load_item(self, query) + + if load_result: + return load_result - async def get_tracks(self, query: str, node: Node = None, check_local: bool = False) -> LoadResult: + return LoadResult.empty() + + async def get_tracks(self, query: str, node: Optional[Node] = None, + check_local: bool = False) -> LoadResult: """|coro| Retrieves a list of results pertaining to the provided query. @@ -257,15 +320,10 @@ async def get_tracks(self, query: str, node: Node = None, check_local: bool = Fa if load_result: return load_result - if not self.node_manager.available_nodes: - raise NodeError('No available nodes!') - node = node or random.choice(self.node_manager.available_nodes) - res = await self._get_request('{}/loadtracks'.format(node.http_uri), - params={'identifier': query}, - headers={'Authorization': node.password}) - return LoadResult.from_dict(res) + node = node or random.choice(self.node_manager.nodes) + return await node.get_tracks(query) - async def decode_track(self, track: str, node: Node = None): + async def decode_track(self, track: str, node: Optional[Node] = None) -> AudioTrack: """|coro| Decodes a base64-encoded track string into a dict. @@ -279,20 +337,15 @@ async def decode_track(self, track: str, node: Node = None): Returns ------- - :class:`dict` - A dict representing the track's information. + :class:`AudioTrack` """ - if not self.node_manager.available_nodes: - raise NodeError('No available nodes!') - node = node or random.choice(self.node_manager.available_nodes) - return await self._get_request('{}/decodetrack'.format(node.http_uri), - params={'track': track}, - headers={'Authorization': node.password}) - - async def decode_tracks(self, tracks: list, node: Node = None): + node = node or random.choice(self.node_manager.nodes) + return await node.decode_track(track) + + async def decode_tracks(self, tracks: List[str], node: Optional[Node] = None) -> List[AudioTrack]: """|coro| - Decodes a list of base64-encoded track strings into a dict. + Decodes a list of base64-encoded track strings into ``AudioTrack``s. Parameters ---------- @@ -303,18 +356,13 @@ async def decode_tracks(self, tracks: list, node: Node = None): Returns ------- - List[:class:`dict`] - A list of dicts representing track information. + List[:class:`AudioTrack`] + A list of decoded ``AudioTrack``s. """ - if not self.node_manager.available_nodes: - raise NodeError('No available nodes!') - node = node or random.choice(self.node_manager.available_nodes) - - return await self._post_request('{}/decodetracks'.format(node.http_uri), - json=tracks, - headers={'Authorization': node.password}) + node = node or random.choice(self.node_manager.nodes) + return await node.decode_tracks(tracks) - async def voice_update_handler(self, data): + async def voice_update_handler(self, data: Dict[str, Any]): """|coro| This function intercepts websocket data from your Discord library and @@ -329,7 +377,7 @@ async def voice_update_handler(self, data): Parameters ---------- - data: :class:`dict` + data: Dict[str, Any] The payload received from Discord. """ if not data or 't' not in data: @@ -342,7 +390,7 @@ async def voice_update_handler(self, data): if player: await player._voice_server_update(data['d']) elif data['t'] == 'VOICE_STATE_UPDATE': - if int(data['d']['user_id']) != int(self._user_id): + if int(data['d']['user_id']) != self._user_id: return guild_id = int(data['d']['guild_id']) @@ -351,30 +399,11 @@ async def voice_update_handler(self, data): if player: await player._voice_state_update(data['d']) - async def _get_request(self, url, **kwargs): - async with self._session.get(url, **kwargs) as res: - if res.status == 401 or res.status == 403: - raise AuthenticationError - - if res.status == 200: - return await res.json() - - raise NodeError('An invalid response was received from the node: code={}, body={}' - .format(res.status, await res.text())) - - async def _post_request(self, url, **kwargs): - async with self._session.post(url, **kwargs) as res: - if res.status == 401 or res.status == 403: - raise AuthenticationError - - if 'json' in kwargs: - if res.status == 200: - return await res.json() - - raise NodeError('An invalid response was received from the node: code={}, body={}' - .format(res.status, await res.text())) - - return res.status == 204 + def has_listeners(self, event: Type[Event]) -> bool: + """ + Check whether the client has any listeners for a specific event type. + """ + return len(self._event_hooks['Generic']) > 0 or len(self._event_hooks[event.__name__]) > 0 async def _dispatch_event(self, event: Event): """|coro| @@ -386,8 +415,8 @@ async def _dispatch_event(self, event: Event): event: :class:`Event` The event to dispatch to the hooks. """ - generic_hooks = Client._event_hooks['Generic'] - targeted_hooks = Client._event_hooks[type(event).__name__] + generic_hooks = self._event_hooks['Generic'] + targeted_hooks = self._event_hooks[type(event).__name__] if not generic_hooks and not targeted_hooks: return @@ -397,9 +426,6 @@ async def _hook_wrapper(hook, event): await hook(event) except: # noqa: E722 pylint: disable=bare-except _log.exception('Event hook \'%s\' encountered an exception!', hook.__name__) - # According to https://stackoverflow.com/questions/5191830/how-do-i-log-a-python-error-with-debug-information - # the exception information should automatically be attached here. We're just including a message for - # clarity. tasks = [_hook_wrapper(hook, event) for hook in itertools.chain(generic_hooks, targeted_hooks)] await asyncio.gather(*tasks) @@ -407,4 +433,4 @@ async def _hook_wrapper(hook, event): _log.debug('Dispatched \'%s\' to all registered hooks', type(event).__name__) def __repr__(self): - return ''.format(self._user_id, len(self.node_manager), len(self.player_manager)) + return f'' diff --git a/lavalink/common.py b/lavalink/common.py new file mode 100644 index 00000000..0c7c80bf --- /dev/null +++ b/lavalink/common.py @@ -0,0 +1,24 @@ +""" +MIT License + +Copyright (c) 2017-present Devoxin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +MISSING = object() diff --git a/lavalink/datarw.py b/lavalink/dataio.py similarity index 83% rename from lavalink/datarw.py rename to lavalink/dataio.py index 840c07bf..395d3a7d 100644 --- a/lavalink/datarw.py +++ b/lavalink/dataio.py @@ -24,41 +24,45 @@ import struct from base64 import b64decode from io import BytesIO +from typing import Optional from .utfm_codec import read_utfm class DataReader: - def __init__(self, ts): - self._buf = BytesIO(b64decode(ts)) + def __init__(self, base64_str: str): + self._buf: BytesIO = BytesIO(b64decode(base64_str)) def _read(self, count): return self._buf.read(count) - def read_byte(self): + def read_byte(self) -> int: return self._read(1) - def read_boolean(self): + def read_boolean(self) -> bool: result, = struct.unpack('B', self.read_byte()) return result != 0 - def read_unsigned_short(self): + def read_unsigned_short(self) -> int: result, = struct.unpack('>H', self._read(2)) return result - def read_int(self): + def read_int(self) -> int: result, = struct.unpack('>i', self._read(4)) return result - def read_long(self): + def read_long(self) -> int: result, = struct.unpack('>Q', self._read(8)) return result - def read_utf(self): + def read_nullable_utf(self) -> Optional[str]: + return self.read_utf().decode() if self.read_boolean() else None + + def read_utf(self) -> bytes: text_length = self.read_unsigned_short() return self._read(text_length) - def read_utfm(self): + def read_utfm(self) -> str: text_length = self.read_unsigned_short() utf_string = self._read(text_length) return read_utfm(text_length, utf_string) @@ -90,6 +94,12 @@ def write_long(self, long_value): enc = struct.pack('>Q', long_value) self._write(enc) + def write_nullable_utf(self, utf_string): + self.write_boolean(bool(utf_string)) + + if utf_string: + self.write_utf(utf_string) + def write_utf(self, utf_string): utf = utf_string.encode('utf8') byte_len = len(utf) diff --git a/lavalink/errors.py b/lavalink/errors.py index e4b233dd..09650407 100644 --- a/lavalink/errors.py +++ b/lavalink/errors.py @@ -21,10 +21,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import TYPE_CHECKING, Any, Dict, Optional +if TYPE_CHECKING: + from .player import BasePlayer -class NodeError(Exception): - """ Raised when something went wrong with a node. """ + +class ClientError(Exception): + """ Raised when something goes wrong within the client. """ class AuthenticationError(Exception): @@ -37,3 +41,57 @@ class InvalidTrack(Exception): class LoadError(Exception): """ Raised when a track fails to load. E.g. if a DeferredAudioTrack fails to find an equivalent. """ + + +class RequestError(Exception): + """ + Raised when a request to the Lavalink server fails. + + Attributes + ---------- + status: :class:`int` + The HTTP status code returned by the server. + timestamp: :class:`int` + The epoch timestamp in milliseconds, at which the error occurred. + error: :class:`str` + The HTTP status code message. + message: :class:`str` + The error message. + path: :class:`str` + The request path. + trace: Optional[:class:`str`] + The stack trace of the error. This will only be present if ``trace=true`` was provided + in the query parameters of the request. + params: Dict[str, Any] + The parameters passed to the request that errored. + """ + __slots__ = ('status', 'timestamp', 'error', 'message', 'path', 'trace', 'params') + + def __init__(self, message, status: int, response: dict, params: Dict[str, Any]): + super().__init__(message) + self.status: int = status + self.timestamp: int = response['timestamp'] + self.error: str = response['error'] + self.message: str = response['message'] + self.path: str = response['path'] + self.trace: Optional[str] = response.get('trace', None) + self.params = params + + +class PlayerErrorEvent(Exception): + """ + Raised when an error occurs within a :class:`BasePlayer`. + + Attributes + ---------- + player: :class:`BasePlayer` + The player in which the error occurred. + original: :class:`Exception` + The original error. + """ + __slots__ = ('player', 'original') + + def __init__(self, player, original): + self.player: 'BasePlayer' = player + self.original: Exception = original + # TODO: Perhaps an enum denoting which area of the player encountered an exception, e.g. ErrorType.PLAY. diff --git a/lavalink/events.py b/lavalink/events.py index 9b147b07..1aaad0c0 100644 --- a/lavalink/events.py +++ b/lavalink/events.py @@ -21,12 +21,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from .server import EndReason, Severity if TYPE_CHECKING: - # pylint: disable=cyclic-import - from .models import AudioTrack, BasePlayer, DeferredAudioTrack + from .abc import BasePlayer, DeferredAudioTrack from .node import Node + from .server import AudioTrack class Event: @@ -35,7 +37,7 @@ class Event: class TrackStartEvent(Event): """ - This event is emitted when the player starts to play a track. + This event is emitted when a track begins playing (e.g. via player.play()) Attributes ---------- @@ -53,18 +55,18 @@ def __init__(self, player, track): class TrackStuckEvent(Event): """ - This event is emitted when the currently playing track is stuck. - This normally has something to do with the stream you are playing - and not Lavalink itself. + This event is emitted when the currently playing track is stuck (i.e. has not provided any audio). + This is typically a fault of the track's underlying audio stream, and not Lavalink itself. Attributes ---------- player: :class:`BasePlayer` - The player that has the playing track being stuck. + The player associated with the stuck track. track: :class:`AudioTrack` - The track is stuck from playing. + The stuck track. threshold: :class:`int` - The amount of time the track had while being stuck. + The configured threshold, in milliseconds, after which a track is considered to be stuck + when no audio frames were provided. """ __slots__ = ('player', 'track', 'threshold') @@ -76,7 +78,7 @@ def __init__(self, player, track, threshold): class TrackExceptionEvent(Event): """ - This event is emitted when an exception occurs while playing a track. + This event is emitted when a track encounters an exception during playback. Attributes ---------- @@ -84,18 +86,21 @@ class TrackExceptionEvent(Event): The player that had the exception occur while playing a track. track: :class:`AudioTrack` The track that had the exception while playing. - exception: :class:`str` - The type of exception that the track had while playing. - severity: :class:`str` - The level of severity of the exception. + message: Optional[:class:`str`] + The exception message. + severity: :enum:`Severity` + The severity of the exception. + cause: :class:`str` + The cause of the exception. """ - __slots__ = ('player', 'track', 'exception', 'severity') + __slots__ = ('player', 'track', 'message', 'severity', 'cause') - def __init__(self, player, track, exception, severity): + def __init__(self, player, track, message, severity, cause): self.player: 'BasePlayer' = player self.track: 'AudioTrack' = track - self.exception: str = exception - self.severity: str = severity + self.message: Optional[str] = message + self.severity: Severity = severity + self.cause: str = cause class TrackEndEvent(Event): @@ -109,7 +114,7 @@ class TrackEndEvent(Event): track: Optional[:class:`AudioTrack`] The track that finished playing. This could be ``None`` if Lavalink fails to encode the track. - reason: :class:`str` + reason: :class:`EndReason` The reason why the track stopped playing. """ __slots__ = ('player', 'track', 'reason') @@ -117,7 +122,7 @@ class TrackEndEvent(Event): def __init__(self, player, track, reason): self.player: 'BasePlayer' = player self.track: Optional['AudioTrack'] = track - self.reason: str = reason + self.reason: EndReason = reason class TrackLoadFailedEvent(Event): @@ -251,9 +256,30 @@ def __init__(self, player, old_node, new_node): self.new_node: 'Node' = new_node -class WebSocketClosedEvent(Event): +class NodeReadyEvent(Event): """ + This is a custom event, emitted when a node becomes ready. + A node is considered ready once it receives the "ready" event from the Lavalink server. + Attributes + ---------- + node: :class:`Node` + The node that became ready. + session_id: :class:`str` + The ID of the session. + resumed: :class:`bool` + Whether the session was resumed. This will be false if a brand new session was created. + """ + __slots__ = ('node', 'session_id', 'resumed') + + def __init__(self, node, session_id, resumed): + self.node: 'Node' = node + self.session_id: str = session_id + self.resumed: bool = resumed + + +class WebSocketClosedEvent(Event): + """ This event is emitted when an audio websocket to Discord is closed. This can happen happen for various reasons, an example being when a channel is deleted. @@ -279,3 +305,24 @@ def __init__(self, player, code, reason, by_remote): self.code: int = code self.reason: str = reason self.by_remote: bool = by_remote + + +class IncomingWebSocketMessage(Event): + """ + This event is emitted whenever the client receives a websocket message from the Lavalink server. + + You can use this to extend the functionality of the client, particularly useful when used with + Lavalink server plugins that can add new HTTP routes or websocket messages. + + Attributes + ---------- + data: Union[Dict[Any, Any], List[Any]] + The received JSON-formatted data from the websocket. + node: :class:`Node` + The node responsible for this websocket message. + """ + __slots__ = ('data', 'node') + + def __init__(self, data, node): + self.data: Union[Dict[Any, Any], List[Any]] = data + self.node: 'Node' = node diff --git a/lavalink/filters.py b/lavalink/filters.py index b91ac77a..c4b0704d 100644 --- a/lavalink/filters.py +++ b/lavalink/filters.py @@ -21,20 +21,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Union +# Necessary evil due to documenting filter update kwargs. +# At least, until I can come up with a better solution to doing this. +# pylint: disable=arguments-differ +from typing import List, Tuple, overload -class Filter: - def __init__(self, values: Union[dict, list, float]): - self.values = values - - def update(self, **kwargs): - """ Updates the internal values to match those provided. """ - raise NotImplementedError - - def serialize(self) -> dict: - """ Transforms the internal values into a dict matching the structure Lavalink expects. """ - raise NotImplementedError +from .abc import Filter class Volume(Filter): @@ -44,6 +37,10 @@ class Volume(Filter): def __init__(self): super().__init__(1.0) + @overload + def update(self, volume: float): + ... + def update(self, **kwargs): """ Modifies the player volume. @@ -81,6 +78,14 @@ class Equalizer(Filter): def __init__(self): super().__init__([0.0] * 15) + @overload + def update(self, bands: List[Tuple[int, float]]): + ... + + @overload + def update(self, band: int, gain: int): + ... + def update(self, **kwargs): """ Modifies the gain of each specified band. @@ -111,6 +116,10 @@ def update(self, **kwargs): The band to modify. gain: :class:`float` The new gain of the band. + + Raises + ------ + :class:`ValueError` """ if 'bands' in kwargs: bands = kwargs.pop('bands') @@ -128,8 +137,6 @@ def update(self, **kwargs): elif 'band' in kwargs and 'gain' in kwargs: band = int(kwargs.pop('band')) gain = float(kwargs.pop('gain')) - # don't bother propagating the potential ValueErrors raised by these 2 statements - # The users can handle those. if not 0 <= band <= 14: raise ValueError('Band must be between 0 and 14 (start and end inclusive)') @@ -148,11 +155,39 @@ def serialize(self) -> dict: class Karaoke(Filter): """ Allows for isolating a frequency range (commonly, the vocal range). - Useful for 'karaoke'/sing-along. + Useful for karaoke/sing-along. """ def __init__(self): super().__init__({'level': 1.0, 'monoLevel': 1.0, 'filterBand': 220.0, 'filterWidth': 100.0}) + @overload + def update(self, level: float): + ... + + @overload + def update(self, mono_level: float): + ... + + @overload + def update(self, filter_width: float): + ... + + @overload + def update(self, level: float, mono_level: float): + ... + + @overload + def update(self, level: float, filter_width: float): + ... + + @overload + def update(self, mono_level: float, filter_width: float): + ... + + @overload + def update(self, level: float, mono_level: float, filter_width: float): + ... + def update(self, **kwargs): """ Parameters @@ -165,6 +200,10 @@ def update(self, **kwargs): The frequency of the band to filter. filter_width: :class:`float` The width of the filter. + + Raises + ------ + :class:`ValueError` """ if 'level' in kwargs: self.values['level'] = float(kwargs.pop('level')) @@ -189,6 +228,34 @@ class Timescale(Filter): def __init__(self): super().__init__({'speed': 1.0, 'pitch': 1.0, 'rate': 1.0}) + @overload + def update(self, speed: float): + ... + + @overload + def update(self, pitch: float): + ... + + @overload + def update(self, rate: float): + ... + + @overload + def update(self, speed: float, pitch: float): + ... + + @overload + def update(self, speed: float, rate: float): + ... + + @overload + def update(self, rate: float, pitch: float): + ... + + @overload + def update(self, speed: float, rate: float, pitch: float): + ... + def update(self, **kwargs): """ Note @@ -209,6 +276,10 @@ def update(self, **kwargs): The pitch of the audio. rate: :class:`float` The playback rate. + + Raises + ------ + :class:`ValueError` """ if 'speed' in kwargs: speed = float(kwargs.pop('speed')) @@ -245,6 +316,18 @@ class Tremolo(Filter): def __init__(self): super().__init__({'frequency': 2.0, 'depth': 0.5}) + @overload + def update(self, frequency: float): + ... + + @overload + def update(self, depth: float): + ... + + @overload + def update(self, frequency: float, depth: float): + ... + def update(self, **kwargs): """ Note @@ -261,6 +344,10 @@ def update(self, **kwargs): How frequently the effect should occur. depth: :class:`float` The "strength" of the effect. + + Raises + ------ + :class:`ValueError` """ if 'frequency' in kwargs: frequency = float(kwargs.pop('frequency')) @@ -289,6 +376,18 @@ class Vibrato(Filter): def __init__(self): super().__init__({'frequency': 2.0, 'depth': 0.5}) + @overload + def update(self, frequency: float): + ... + + @overload + def update(self, depth: float): + ... + + @overload + def update(self, frequency: float, depth: float): + ... + def update(self, **kwargs): """ Note @@ -305,6 +404,10 @@ def update(self, **kwargs): How frequently the effect should occur. depth: :class:`float` The "strength" of the effect. + + Raises + ------ + :class:`ValueError` """ if 'frequency' in kwargs: frequency = float(kwargs.pop('frequency')) @@ -332,7 +435,11 @@ class Rotation(Filter): This is commonly used to create the 8D effect. """ def __init__(self): - super().__init__({'rotationHz': 0.0}) + super().__init__(0.0) + + @overload + def update(self, rotation_hz: float): + ... def update(self, **kwargs): """ @@ -346,17 +453,21 @@ def update(self, **kwargs): ---------- rotation_hz: :class:`float` How frequently the effect should occur. + + Raises + ------ + :class:`ValueError` """ if 'rotation_hz' in kwargs: rotation_hz = float(kwargs.pop('rotation_hz')) if rotation_hz < 0: - raise ValueError('rotationHz must be bigger than or equal to 0') + raise ValueError('rotation_hz must be bigger than or equal to 0') - self.values['rotationHz'] = rotation_hz + self.values = rotation_hz def serialize(self) -> dict: - return {'rotation': self.values} + return {'rotation': {'rotationHz': self.values}} class LowPass(Filter): @@ -365,7 +476,11 @@ class LowPass(Filter): effectively cutting off high frequencies meaning more emphasis is put on lower frequencies. """ def __init__(self): - super().__init__({'smoothing': 20.0}) + super().__init__(20.0) + + @overload + def update(self, smoothing: float): + ... def update(self, **kwargs): """ @@ -379,6 +494,10 @@ def update(self, **kwargs): ---------- smoothing: :class:`float` The strength of the effect. + + Raises + ------ + :class:`ValueError` """ if 'smoothing' in kwargs: smoothing = float(kwargs.pop('smoothing')) @@ -386,10 +505,10 @@ def update(self, **kwargs): if smoothing <= 1: raise ValueError('smoothing must be bigger than 1') - self.values['smoothing'] = smoothing + self.values = smoothing def serialize(self) -> dict: - return {'lowPass': self.values} + return {'lowPass': {'smoothing': self.values}} class ChannelMix(Filter): @@ -424,6 +543,10 @@ def update(self, **kwargs): The volume level of the audio going from the "Right" channel to the "Left" channel. right_to_right: :class:`float` The volume level of the audio going from the "Right" channel to the "Left" channel. + + Raises + ------ + :class:`ValueError` """ if 'left_to_left' in kwargs: left_to_left = float(kwargs.pop('left_to_left')) @@ -489,6 +612,10 @@ def update(self, **kwargs): The sin offset. scale: :class:`float` The sin scale. + + Raises + ------ + :class:`ValueError` """ if 'sin_offset' in kwargs: self.values['sinOffset'] = float(kwargs.pop('sin_offset')) diff --git a/lavalink/models.py b/lavalink/models.py deleted file mode 100644 index 07f8138c..00000000 --- a/lavalink/models.py +++ /dev/null @@ -1,1183 +0,0 @@ -""" -MIT License - -Copyright (c) 2017-present Devoxin - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -""" -# pylint: disable=too-many-lines -from abc import ABC, abstractmethod -from enum import Enum -from random import randrange -from time import time -from typing import TYPE_CHECKING, Dict, List, Optional, Union - -from .errors import InvalidTrack, LoadError -from .events import (NodeChangedEvent, QueueEndEvent, TrackEndEvent, - TrackExceptionEvent, TrackLoadFailedEvent, - TrackStartEvent, TrackStuckEvent) -from .filters import Equalizer, Filter - -if TYPE_CHECKING: - # pylint: disable=cyclic-import - from .client import Client - from .node import Node - - -class AudioTrack: - """ - Represents an AudioTrack. - - Parameters - ---------- - data: Union[:class:`dict`, :class:`AudioTrack`] - The data to initialise an AudioTrack from. - requester: :class:`any` - The requester of the track. - extra: :class:`dict` - Any extra information to store in this AudioTrack. - - Attributes - ---------- - track: Optional[:class:`str`] - The base64-encoded string representing a Lavalink-readable AudioTrack. - This is marked optional as it could be None when it's not set by a custom :class:`Source`, - which is expected behaviour when the subclass is a :class:`DeferredAudioTrack`. - identifier: :class:`str` - The track's id. For example, a youtube track's identifier will look like dQw4w9WgXcQ. - is_seekable: :class:`bool` - Whether the track supports seeking. - author: :class:`str` - The track's uploader. - duration: :class:`int` - The duration of the track, in milliseconds. - stream: :class:`bool` - Whether the track is a live-stream. - title: :class:`str` - The title of the track. - uri: :class:`str` - The full URL of track. - position: :class:`int` - The playback position of the track, in milliseconds. - This is a read-only property; setting it won't have any effect. - source_name: :class:`str` - The name of the source that this track was created by. - requester: :class:`int` - The ID of the user that requested this track. - extra: :class:`dict` - Any extra properties given to this AudioTrack will be stored here. - """ - __slots__ = ('_raw', 'track', 'identifier', 'is_seekable', 'author', 'duration', 'stream', 'title', 'uri', - 'position', 'source_name', 'extra') - - def __init__(self, data: dict, requester: int, **extra): - try: - if isinstance(data, AudioTrack): - extra = {**data.extra, **extra} - data = data._raw - - self._raw = data - - info = data.get('info', data) - self.track: Optional[str] = data.get('track') - self.identifier: str = info['identifier'] - self.is_seekable: bool = info['isSeekable'] - self.author: str = info['author'] - self.duration: int = info['length'] - self.stream: bool = info['isStream'] - self.title: str = info['title'] - self.uri: str = info['uri'] - self.position: int = info.get('position', 0) - self.source_name: str = info.get('sourceName', 'unknown') - self.extra: dict = {**extra, 'requester': requester} - except KeyError as ke: - missing_key, = ke.args - raise InvalidTrack('Cannot build a track from partial data! (Missing key: {})'.format(missing_key)) from None - - def __getitem__(self, name): - if name == 'info': - return self - - return super().__getattribute__(name) - - @property - def requester(self) -> int: - return self.extra['requester'] - - @requester.setter - def requester(self, requester) -> int: - self.extra['requester'] = requester - - def __repr__(self): - return ''.format(self) - - -class DeferredAudioTrack(ABC, AudioTrack): - """ - Similar to an :class:`AudioTrack`, however this track only stores metadata up until it's - played, at which time :func:`load` is called to retrieve a base64 string which is then used for playing. - - Note - ---- - For implementation: The ``track`` field need not be populated as this is done later via - the :func:`load` method. You can optionally set ``self.track`` to the result of :func:`load` - during implementation, as a means of caching the base64 string to avoid fetching it again later. - This should serve the purpose of speeding up subsequent play calls in the event of repeat being enabled, - for example. - """ - @abstractmethod - async def load(self, client: 'Client'): - """|coro| - - Retrieves a base64 string that's playable by Lavalink. - For example, you can use this method to search Lavalink for an identical track from other sources, - which you can then use the base64 string of to play the track on Lavalink. - - Parameters - ---------- - client: :class:`Client` - This will be an instance of the Lavalink client 'linked' to this track. - - Returns - ------- - :class:`str` - A Lavalink-compatible base64 string containing encoded track metadata. - """ - raise NotImplementedError - - -class LoadType(Enum): - TRACK = 'TRACK_LOADED' - PLAYLIST = 'PLAYLIST_LOADED' - SEARCH = 'SEARCH_RESULT' - NO_MATCHES = 'NO_MATCHES' - LOAD_FAILED = 'LOAD_FAILED' - - def __eq__(self, other): - if self.__class__ is other.__class__: - return self.value == other.value # pylint: disable=comparison-with-callable - - if isinstance(other, str): - return self.value == other # pylint: disable=comparison-with-callable - - raise NotImplementedError - - @classmethod - def from_str(cls, other: str): - try: - return cls[other.upper()] - except KeyError: - try: - return cls(other.upper()) - except ValueError as ve: - raise ValueError('{} is not a valid LoadType enum!'.format(other)) from ve - - -class PlaylistInfo: - """ - Attributes - ---------- - name: :class:`str` - The name of the playlist. - selected_track: :class:`int` - The index of the selected/highlighted track. - This will be -1 if there is no selected track. - """ - def __init__(self, name: str, selected_track: int = -1): - self.name: str = name - self.selected_track: int = selected_track - - def __getitem__(self, k): # Exists only for compatibility, don't blame me - if k == 'selectedTrack': - k = 'selected_track' - return self.__getattribute__(k) - - @classmethod - def from_dict(cls, mapping: dict): - return cls(mapping.get('name'), mapping.get('selectedTrack', -1)) - - @classmethod - def none(cls): - return cls('', -1) - - def __repr__(self): - return ''.format(self) - - -class LoadResult: - """ - Attributes - ---------- - load_type: :class:`LoadType` - The load type of this result. - tracks: List[Union[:class:`AudioTrack`, :class:`DeferredAudioTrack`]] - The tracks in this result. - playlist_info: :class:`PlaylistInfo` - The playlist metadata for this result. - The :class:`PlaylistInfo` could contain empty/false data if the :class:`LoadType` - is not :enum:`LoadType.PLAYLIST`. - """ - def __init__(self, load_type: LoadType, tracks: List[Union[AudioTrack, DeferredAudioTrack]], - playlist_info: Optional[PlaylistInfo] = PlaylistInfo.none()): - self.load_type: LoadType = load_type - self.playlist_info: PlaylistInfo = playlist_info - self.tracks: List[Union[AudioTrack, DeferredAudioTrack]] = tracks - - def __getitem__(self, k): # Exists only for compatibility, don't blame me - if k == 'loadType': - k = 'load_type' - elif k == 'playlistInfo': - k = 'playlist_info' - - return self.__getattribute__(k) - - @classmethod - def from_dict(cls, mapping: dict): - load_type = LoadType.from_str(mapping.get('loadType')) - playlist_info = PlaylistInfo.from_dict(mapping.get('playlistInfo')) - tracks = [AudioTrack(track, 0) for track in mapping.get('tracks')] - return cls(load_type, tracks, playlist_info) - - @property - def selected_track(self) -> Optional[AudioTrack]: - """ - Convenience method for returning the selected track using - :attr:`PlaylistInfo.selected_track`. - - This could be ``None`` if :attr:`playlist_info` is ``None``, - or :attr:`PlaylistInfo.selected_track` is an invalid number. - - Returns - ------- - Optional[:class:`AudioTrack`] - """ - if self.playlist_info is not None: - index = self.playlist_info.selected_track - - if 0 <= index < len(self.tracks): - return self.tracks[index] - - return None - - def __repr__(self): - return ''.format(self, len(self.tracks)) - - -class Source(ABC): - def __init__(self, name: str): - self.name: str = name - - def __eq__(self, other): - if self.__class__ is other.__class__: - return self.name == other.name - - raise NotImplementedError - - def __hash__(self): - return hash(self.name) - - @abstractmethod - async def load_item(self, client: 'Client', query: str) -> Optional[LoadResult]: - """|coro| - - Loads a track with the given query. - - Parameters - ---------- - client: :class:`Client` - The Lavalink client. This could be useful for performing a Lavalink search - for an identical track from other sources, if needed. - query: :class:`str` - The search query that was provided. - - Returns - ------- - Optional[:class:`LoadResult`] - A LoadResult, or None if there were no matches for the provided query. - """ - raise NotImplementedError - - def __repr__(self): - return ''.format(self) - - -class BasePlayer(ABC): - """ - Represents the BasePlayer all players must be inherited from. - - Attributes - ---------- - guild_id: :class:`int` - The guild id of the player. - node: :class:`Node` - The node that the player is connected to. - channel_id: Optional[:class:`int`] - The ID of the voice channel the player is connected to. - This could be None if the player isn't connected. - """ - def __init__(self, guild_id: int, node: 'Node'): - self._lavalink = node._manager._lavalink - self.guild_id: int = guild_id - self._internal_id: str = str(guild_id) - self.node: 'Node' = node - self._original_node: Optional['Node'] = None # This is used internally for failover. - self._voice_state = {} - self.channel_id: Optional[int] = None - - @abstractmethod - async def _handle_event(self, event): - raise NotImplementedError - - @abstractmethod - async def _update_state(self, state: dict): - raise NotImplementedError - - async def play_track(self, track: str, start_time: Optional[int] = None, end_time: Optional[int] = None, - no_replace: Optional[bool] = None, volume: Optional[int] = None, pause: Optional[bool] = None): - """|coro| - - Plays the given track. - - Parameters - ---------- - track: :class:`str` - The track to play. This must be the base64 string from a track. - start_time: Optional[:class:`int`] - The number of milliseconds to offset the track by. - If left unspecified or ``None`` is provided, the track will start from the beginning. - end_time: Optional[:class:`int`] - The position at which the track should stop playing. - This is an absolute position, so if you want the track to stop at 1 minute, you would pass 60000. - The default behaviour is to play until no more data is received from the remote server. - If left unspecified or ``None`` is provided, the default behaviour is exhibited. - no_replace: Optional[:class:`bool`] - If set to true, operation will be ignored if a track is already playing or paused. - The default behaviour is to always replace. - If left unspecified or None is provided, the default behaviour is exhibited. - volume: Optional[:class:`int`] - The initial volume to set. This is useful for changing the volume between tracks etc. - If left unspecified or ``None`` is provided, the volume will remain at its current setting. - pause: Optional[:class:`bool`] - Whether to immediately pause the track after loading it. - The default behaviour is to never pause. - If left unspecified or ``None`` is provided, the default behaviour is exhibited. - """ - if track is None or not isinstance(track, str): - raise ValueError('track must be a str') - - options = {} - - if start_time is not None: - if not isinstance(start_time, int) or start_time < 0: - raise ValueError('start_time must be an int with a value equal to, or greater than 0') - options['startTime'] = start_time - - if end_time is not None: - if not isinstance(end_time, int) or not end_time >= 1: - raise ValueError('end_time must be an int with a value equal to, or greater than 1') - - if end_time > 0: - options['endTime'] = end_time - - if no_replace is not None: - if not isinstance(no_replace, bool): - raise TypeError('no_replace must be a bool') - options['noReplace'] = no_replace - - if volume is not None: - if not isinstance(volume, int): - raise TypeError('volume must be an int') - self.volume = max(min(volume, 1000), 0) - options['volume'] = self.volume - - if pause is not None: - if not isinstance(pause, bool): - raise TypeError('pause must be a bool') - options['pause'] = pause - - await self.node._send(op='play', guildId=self._internal_id, track=track, **options) - - def cleanup(self): - pass - - async def destroy(self): - """|coro| - - Destroys the current player instance. - - Shortcut for :func:`PlayerManager.destroy`. - """ - await self._lavalink.player_manager.destroy(self.guild_id) - - async def _voice_server_update(self, data): - self._voice_state.update({ - 'event': data - }) - - await self._dispatch_voice_update() - - async def _voice_state_update(self, data): - raw_channel_id = data['channel_id'] - self.channel_id = int(raw_channel_id) if raw_channel_id else None - - if not self.channel_id: # We're disconnecting - self._voice_state.clear() - return - - if data['session_id'] != self._voice_state.get('sessionId'): - self._voice_state.update({ - 'sessionId': data['session_id'] - }) - - await self._dispatch_voice_update() - - async def _dispatch_voice_update(self): - if {'sessionId', 'event'} == self._voice_state.keys(): - await self.node._send(op='voiceUpdate', guildId=self._internal_id, **self._voice_state) - - @abstractmethod - async def node_unavailable(self): - """|coro| - - Called when a player's node becomes unavailable. - Useful for changing player state before it's moved to another node. - """ - raise NotImplementedError - - @abstractmethod - async def change_node(self, node: 'Node'): - """|coro| - - Called when a node change is requested for the current player instance. - - Parameters - ---------- - node: :class:`Node` - The new node to switch to. - """ - raise NotImplementedError - - -class DefaultPlayer(BasePlayer): - """ - The player that Lavalink.py uses by default. - - This should be sufficient for most use-cases. - - Attributes - ---------- - LOOP_NONE: :class:`int` - Class attribute. Disables looping entirely. - LOOP_SINGLE: :class:`int` - Class attribute. Enables looping for a single (usually currently playing) track only. - LOOP_QUEUE: :class:`int` - Class attribute. Enables looping for the entire queue. When a track finishes playing, it'll be added to the end of the queue. - guild_id: :class:`int` - The guild id of the player. - node: :class:`Node` - The node that the player is connected to. - paused: :class:`bool` - Whether or not a player is paused. - position_timestamp: :class:`int` - Returns the track's elapsed playback time as an epoch timestamp. - volume: :class:`int` - The volume at which the player is playing at. - shuffle: :class:`bool` - Whether or not to mix the queue up in a random playing order. - loop: :class:`int` - Whether loop is enabled, and the type of looping. - This is an integer as loop supports multiple states. - - 0 = Loop off. - - 1 = Loop track. - - 2 = Loop queue. - - Example - ------- - .. code:: python - - if player.loop == player.LOOP_NONE: - await ctx.send('Not looping.') - elif player.loop == player.LOOP_SINGLE: - await ctx.send(f'{player.current.title} is looping.') - elif player.loop == player.LOOP_QUEUE: - await ctx.send('This queue never ends!') - filters: Dict[:class:`str`, :class:`Filter`] - A mapping of str to :class:`Filter`, representing currently active filters. - queue: List[:class:`AudioTrack`] - A list of AudioTracks to play. - current: Optional[:class:`AudioTrack`] - The track that is playing currently, if any. - """ - LOOP_NONE: int = 0 - LOOP_SINGLE: int = 1 - LOOP_QUEUE: int = 2 - - def __init__(self, guild_id: int, node: 'Node'): - super().__init__(guild_id, node) - - self._user_data = {} - - self.paused: bool = False - self._internal_pause: bool = False # Toggled when player's node becomes unavailable, primarily used for track position tracking. - self._last_update = 0 - self._last_position = 0 - self.position_timestamp: int = 0 - self.volume: int = 100 - self.shuffle: bool = False - self.loop: int = 0 # 0 = off, 1 = single track, 2 = queue - self.filters: Dict[str, Filter] = {} - - self.queue: List[AudioTrack] = [] - self.current: Optional[AudioTrack] = None - - @property - def repeat(self) -> bool: - """ - Returns the player's loop status. This exists for backwards compatibility, and also as an alias. - - .. deprecated:: 4.0.0 - Use :attr:`loop` instead. - - If ``self.loop`` is 0, the player is NOT looping. - - If ``self.loop`` is 1, the player is looping the single (current) track. - - If ``self.loop`` is 2, the player is looping the entire queue. - """ - return self.loop == 1 or self.loop == 2 - - @property - def is_playing(self) -> bool: - """ Returns the player's track state. """ - return self.is_connected and self.current is not None - - @property - def is_connected(self) -> bool: - """ Returns whether the player is connected to a voicechannel or not. """ - return self.channel_id is not None - - @property - def position(self) -> float: - """ Returns the track's elapsed playback time in milliseconds, adjusted for Lavalink stat interval. """ - if not self.is_playing: - return 0 - - if self.paused or self._internal_pause: - return min(self._last_position, self.current.duration) - - difference = time() * 1000 - self._last_update - return min(self._last_position + difference, self.current.duration) - - def store(self, key: object, value: object): - """ - Stores custom user data. - - Parameters - ---------- - key: :class:`object` - The key of the object to store. - value: :class:`object` - The object to associate with the key. - """ - self._user_data.update({key: value}) - - def fetch(self, key: object, default=None): - """ - Retrieves the related value from the stored user data. - - Parameters - ---------- - key: :class:`object` - The key to fetch. - default: Optional[:class:`any`] - The object that should be returned if the key doesn't exist. Defaults to ``None``. - - Returns - ------- - Optional[:class:`any`] - """ - return self._user_data.get(key, default) - - def delete(self, key: object): - """ - Removes an item from the the stored user data. - - Parameters - ---------- - key: :class:`object` - The key to delete. - - Raises - ------ - :class:`KeyError` - If the key doesn't exist. - """ - try: - del self._user_data[key] - except KeyError: - pass - - def add(self, track: Union[AudioTrack, DeferredAudioTrack, Dict], requester: int = 0, index: int = None): - """ - Adds a track to the queue. - - Parameters - ---------- - track: Union[:class:`AudioTrack`, :class:`DeferredAudioTrack`, :class:`dict`] - The track to add. Accepts either an AudioTrack or - a dict representing a track returned from Lavalink. - requester: :class:`int` - The ID of the user who requested the track. - index: Optional[:class:`int`] - The index at which to add the track. - If index is left unspecified, the default behaviour is to append the track. Defaults to ``None``. - """ - at = track - - if isinstance(track, dict): - at = AudioTrack(track, requester) - - if requester != 0: - at.requester = requester - - if index is None: - self.queue.append(at) - else: - self.queue.insert(index, at) - - async def play(self, track: Optional[Union[AudioTrack, DeferredAudioTrack, Dict]] = None, start_time: Optional[int] = 0, - end_time: Optional[int] = None, no_replace: Optional[bool] = False, volume: Optional[int] = None, - pause: Optional[bool] = False): - """|coro| - - Plays the given track. - - This method differs from :meth:`BasePlayer.play_track` in that it contains additional logic - to handle certain attributes, such as ``loop``, ``shuffle``, and loading a base64 string from :class:`DeferredAudioTrack`. - - :meth:`BasePlayer.play_track` is a no-frills, raw function which will unconditionally tell the node to play exactly whatever - it is passed. - - Parameters - ---------- - track: Optional[Union[:class:`DeferredAudioTrack`, :class:`AudioTrack`, :class:`dict`]] - The track to play. If left unspecified, this will default - to the first track in the queue. Defaults to ``None`` so plays the next - song in queue. Accepts either an AudioTrack or a dict representing a track - returned from Lavalink. - start_time: Optional[:class:`int`] - The number of milliseconds to offset the track by. - If left unspecified or ``None`` is provided, the track will start from the beginning. - end_time: Optional[:class:`int`] - The position at which the track should stop playing. - This is an absolute position, so if you want the track to stop at 1 minute, you would pass 60000. - The default behaviour is to play until no more data is received from the remote server. - If left unspecified or ``None`` is provided, the default behaviour is exhibited. - no_replace: Optional[:class:`bool`] - If set to true, operation will be ignored if a track is already playing or paused. - The default behaviour is to always replace. - If left unspecified or None is provided, the default behaviour is exhibited. - volume: Optional[:class:`int`] - The initial volume to set. This is useful for changing the volume between tracks etc. - If left unspecified or ``None`` is provided, the volume will remain at its current setting. - pause: Optional[:class:`bool`] - Whether to immediately pause the track after loading it. - The default behaviour is to never pause. - If left unspecified or ``None`` is provided, the default behaviour is exhibited. - - Raises - ------ - :class:`ValueError` - If invalid values were provided for ``start_time`` or ``end_time``. - :class:`TypeError` - If wrong types were provided for ``no_replace``, ``volume`` or ``pause``. - """ - if no_replace and self.is_playing: - return - - if track is not None and isinstance(track, dict): - track = AudioTrack(track, 0) - - if self.loop > 0 and self.current: - if self.loop == 1: - if track is not None: - self.queue.insert(0, self.current) - else: - track = self.current - if self.loop == 2: - self.queue.append(self.current) - - self._last_position = 0 - self.position_timestamp = 0 - self.paused = pause - - if not track: - if not self.queue: - await self.stop() # Also sets current to None. - await self.node._dispatch_event(QueueEndEvent(self)) - return - - pop_at = randrange(len(self.queue)) if self.shuffle else 0 - track = self.queue.pop(pop_at) - - if start_time is not None: - if not isinstance(start_time, int) or not 0 <= start_time < track.duration: - raise ValueError('start_time must be an int with a value equal to, or greater than 0, and less than the track duration') - - if end_time is not None: - if not isinstance(end_time, int) or not 1 <= end_time <= track.duration: - raise ValueError('end_time must be an int with a value equal to, or greater than 1, and less than, or equal to the track duration') - - self.current = track - playable_track = track.track - - if playable_track is None: - if not isinstance(track, DeferredAudioTrack): - raise InvalidTrack('Cannot play the AudioTrack as \'track\' is None, and it is not a DeferredAudioTrack!') - - try: - playable_track = await track.load(self.node._manager._lavalink) - except LoadError as load_error: - await self.node._dispatch_event(TrackLoadFailedEvent(self, track, load_error)) - - if playable_track is None: # This should only fire when a DeferredAudioTrack fails to yield a base64 track string. - await self.node._dispatch_event(TrackLoadFailedEvent(self, track, None)) - return - - await self.play_track(playable_track, start_time, end_time, no_replace, volume, pause) - await self.node._dispatch_event(TrackStartEvent(self, track)) - # TODO: Figure out a better solution for the above. Custom player implementations may neglect - # to dispatch TrackStartEvent leading to confusion and poor user experience. - - async def stop(self): - """|coro| - - Stops the player. - """ - await self.node._send(op='stop', guildId=self._internal_id) - self.current = None - - async def skip(self): - """|coro| - - Plays the next track in the queue, if any. - """ - await self.play() - - def set_repeat(self, repeat: bool): - """ - Sets whether tracks should be repeated. - - .. deprecated:: 4.0.0 - Use :func:`set_loop` to repeat instead. - - This only works as a "queue loop". For single-track looping, you should - utilise the :class:`TrackEndEvent` event to feed the track back into - :func:`play`. - - Also known as ``loop``. - - Parameters - ---------- - repeat: :class:`bool` - Whether to repeat the player or not. - """ - self.loop = 2 if repeat else 0 - - def set_loop(self, loop: int): - """ - Sets whether the player loops between a single track, queue or none. - - 0 = off, 1 = single track, 2 = queue. - - Parameters - ---------- - loop: :class:`int` - The loop setting. 0 = off, 1 = single track, 2 = queue. - """ - if not 0 <= loop <= 2: - raise ValueError('Loop must be 0, 1 or 2.') - - self.loop = loop - - def set_shuffle(self, shuffle: bool): - """ - Sets the player's shuffle state. - - Parameters - ---------- - shuffle: :class:`bool` - Whether to shuffle the player or not. - """ - self.shuffle = shuffle - - async def set_pause(self, pause: bool): - """|coro| - - Sets the player's paused state. - - Parameters - ---------- - pause: :class:`bool` - Whether to pause the player or not. - """ - self.paused = pause - await self.node._send(op='pause', guildId=self._internal_id, pause=pause) - - async def set_volume(self, vol: int): - """|coro| - - Sets the player's volume - - Note - ---- - A limit of 1000 is imposed by Lavalink. - - Parameters - ---------- - vol: :class:`int` - The new volume level. - """ - self.volume = max(min(vol, 1000), 0) - await self.node._send(op='volume', guildId=self._internal_id, volume=self.volume) - - async def seek(self, position: int): - """|coro| - - Seeks to a given position in the track. - - Parameters - ---------- - position: :class:`int` - The new position to seek to in milliseconds. - """ - await self.node._send(op='seek', guildId=self._internal_id, position=position) - - async def set_filter(self, _filter: Filter): - """|coro| - - Applies the corresponding filter within Lavalink. - This will overwrite the filter if it's already applied. - - Example - ------- - .. code:: python - - equalizer = Equalizer() - equalizer.update(bands=[(0, 0.2), (1, 0.3), (2, 0.17)]) - player.set_filter(equalizer) - - Parameters - ---------- - _filter: :class:`Filter` - The filter instance to set. - - Raises - ------ - :class:`TypeError` - If the provided ``_filter`` is not of type :class:`Filter`. - """ - if not isinstance(_filter, Filter): - raise TypeError('Expected object of type Filter, not ' + type(_filter).__name__) - - filter_name = type(_filter).__name__.lower() - self.filters[filter_name] = _filter - await self._apply_filters() - - async def update_filter(self, _filter: Filter, **kwargs): - """|coro| - - Updates a filter using the upsert method; - if the filter exists within the player, its values will be updated; - if the filter does not exist, it will be created with the provided values. - - This will not overwrite any values that have not been provided. - - Example - ------- - .. code :: python - - player.update_filter(Timescale, speed=1.5) - # This means that, if the Timescale filter is already applied - # and it already has set values of "speed=1, pitch=1.2", pitch will remain - # the same, however speed will be changed to 1.5 so the result is - # "speed=1.5, pitch=1.2" - - Parameters - ---------- - _filter: :class:`Filter` - The filter class (**not** an instance of, see above example) to upsert. - **kwargs: :class:`any` - The kwargs to pass to the filter. - - Raises - ------ - :class:`TypeError` - If the provided ``_filter`` is not of type :class:`Filter`. - """ - if isinstance(_filter, Filter): - raise TypeError('Expected class of type Filter, not an instance of ' + type(_filter).__name__) - - if not issubclass(_filter, Filter): - raise TypeError('Expected subclass of type Filter, not ' + _filter.__name__) - - filter_name = _filter.__name__.lower() - - filter_instance = self.filters.get(filter_name, _filter()) - filter_instance.update(**kwargs) - self.filters[filter_name] = filter_instance - await self._apply_filters() - - def get_filter(self, _filter: Union[Filter, str]): - """ - Returns the corresponding filter, if it's enabled. - - Example - ------- - .. code:: python - - from lavalink.filters import Timescale - timescale = player.get_filter(Timescale) - # or - timescale = player.get_filter('timescale') - - Parameters - ---------- - _filter: Union[:class:`Filter`, :class:`str`] - The filter name, or filter class (**not** an instance of, see above example), to get. - - Returns - ------- - Optional[:class:`Filter`] - """ - if isinstance(_filter, str): - filter_name = _filter - elif isinstance(_filter, Filter): # User passed an instance of. - filter_name = type(_filter).__name__ - else: - if not issubclass(_filter, Filter): - raise TypeError('Expected subclass of type Filter, not ' + _filter.__name__) - - filter_name = _filter.__name__ - - return self.filters.get(filter_name.lower(), None) - - async def remove_filter(self, _filter: Union[Filter, str]): - """|coro| - - Removes a filter from the player, undoing any effects applied to the audio. - - Example - ------- - .. code:: python - - player.remove_filter(Timescale) - # or - player.remove_filter('timescale') - - Parameters - ---------- - _filter: Union[:class:`Filter`, :class:`str`] - The filter name, or filter class (**not** an instance of, see above example), to remove. - """ - if isinstance(_filter, str): - filter_name = _filter - elif isinstance(_filter, Filter): # User passed an instance of. - filter_name = type(_filter).__name__ - else: - if not issubclass(_filter, Filter): - raise TypeError('Expected subclass of type Filter, not ' + _filter.__name__) - - filter_name = _filter.__name__ - - fn_lowered = filter_name.lower() - - if fn_lowered in self.filters: - self.filters.pop(fn_lowered) - await self._apply_filters() - - async def clear_filters(self): - """|coro| - - Clears all currently-enabled filters. - """ - self.filters.clear() - await self._apply_filters() - - async def set_gain(self, band: int, gain: float = 0.0): - """|coro| - - Sets the equalizer band gain to the given amount. - - .. deprecated:: 4.0.0 - Use :func:`set_filter` to apply the :class:`Equalizer` filter instead. - - Parameters - ---------- - band: :class:`int` - Band number (0-14). - gain: Optional[:class:`float`] - A float representing gain of a band (-0.25 to 1.00). Defaults to 0.0. - """ - await self.set_gains((band, gain)) - - async def set_gains(self, *bands): - """|coro| - - Modifies the player's equalizer settings. - - .. deprecated:: 4.0.0 - Use :func:`set_filter` to apply the :class:`Equalizer` filter instead. - - Parameters - ---------- - gain_list: :class:`any` - A list of tuples denoting (``band``, ``gain``). - """ - equalizer = Equalizer() - equalizer.update(bands=bands) - await self.set_filter(equalizer) - - async def reset_equalizer(self): - """|coro| - - Resets equalizer to default values. - - .. deprecated:: 4.0.0 - Use :func:`remove_filter` to remove the :class:`Equalizer` filter instead. - """ - await self.remove_filter(Equalizer) - - async def _apply_filters(self): - payload = {} - - for _filter in self.filters.values(): - payload.update(_filter.serialize()) - - await self.node._send(op='filters', guildId=self._internal_id, **payload) - - async def _handle_event(self, event): - """ - Handles the given event as necessary. - - Parameters - ---------- - event: :class:`Event` - The event that will be handled. - """ - if isinstance(event, (TrackStuckEvent, TrackExceptionEvent)) or \ - isinstance(event, TrackEndEvent) and event.reason == 'FINISHED': - await self.play() - - async def _update_state(self, state: dict): - """ - Updates the position of the player. - - Parameters - ---------- - state: :class:`dict` - The state that is given to update. - """ - self._last_update = time() * 1000 - self._last_position = state.get('position', 0) - self.position_timestamp = state.get('time', 0) - - async def node_unavailable(self): - """|coro| - - Called when a player's node becomes unavailable. - Useful for changing player state before it's moved to another node. - """ - self._internal_pause = True - - async def change_node(self, node): - """|coro| - - Changes the player's node - - Parameters - ---------- - node: :class:`Node` - The node the player is changed to. - """ - if self.node.available: - await self.node._send(op='destroy', guildId=self._internal_id) - - old_node = self.node - self.node = node - - if self._voice_state: - await self._dispatch_voice_update() - - if self.current: - playable_track = self.current.track - - if isinstance(self.current, DeferredAudioTrack) and playable_track is None: - playable_track = await self.current.load(self.node._manager._lavalink) - - await self.node._send(op='play', guildId=self._internal_id, track=playable_track, startTime=self.position) - self._last_update = time() * 1000 - - if self.paused: - await self.node._send(op='pause', guildId=self._internal_id, pause=self.paused) - - self._internal_pause = False - - if self.volume != 100: - await self.node._send(op='volume', guildId=self._internal_id, volume=self.volume) - - if self.filters: - await self._apply_filters() - - await self.node._dispatch_event(NodeChangedEvent(self, old_node, node)) - - def __repr__(self): - return ''.format(self) - - -class Plugin: - """ - Represents a Lavalink server plugin. - - Parameters - ---------- - data: :class:`dict` - The data to initialise a Plugin from. - - Attributes - ---------- - name: :class:`str` - The name of the plugin. - version: :class:`str` - The version of the plugin. - """ - __slots__ = ('name', 'version') - - def __init__(self, data: dict): - self.name: str = data['name'] - self.version: str = data['version'] - - def __str__(self): - return '{0.name} v{0.version}'.format(self) - - def __repr__(self): - return ''.format(self) diff --git a/lavalink/node.py b/lavalink/node.py index dfb708bc..4aeb5369 100644 --- a/lavalink/node.py +++ b/lavalink/node.py @@ -21,12 +21,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import List +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload -from .events import Event -from .models import BasePlayer, LoadResult, Plugin # noqa: F401 +from .abc import BasePlayer, Filter +from .common import MISSING +from .errors import ClientError, RequestError +from .server import AudioTrack, LoadResult from .stats import Stats -from .websocket import WebSocket +from .transport import Transport + +if TYPE_CHECKING: + from .client import Client + from .nodemanager import NodeManager class Node: @@ -35,49 +42,41 @@ class Node: Note ---- - Nodes are **NOT** mean't to be added manually, but rather with :func:`Client.add_node`. Doing this can cause - invalid cache and much more problems. + To construct a node, you should use :func:`Client.add_node` instead. Attributes ---------- - host: :class:`str` - The address of the Lavalink node. - port: :class:`int` - The port to use for websocket and REST connections. - password: :class:`str` - The password used for authentication. - ssl: :class:`bool` - Whether this node uses SSL (wss/https). + client: :class:`Client` + The Lavalink client. region: :class:`str` The region to assign this node to. name: :class:`str` The name the :class:`Node` is identified by. - filters: :class:`bool` - Whether or not to use the new ``filters`` op instead of ``equalizer``. - This setting is only used by players. stats: :class:`Stats` The statistics of how the :class:`Node` is performing. """ - def __init__(self, manager, host: str, port: int, password: str, - region: str, resume_key: str, resume_timeout: int, name: str = None, - reconnect_attempts: int = 3, filters: bool = False, ssl: bool = False): - self._lavalink = manager._lavalink - self._manager = manager - self._ws = WebSocket(self, host, port, password, ssl, resume_key, resume_timeout, reconnect_attempts) - - self.host: str = host - self.port: int = port - self.password: str = password - self.ssl: bool = ssl + __slots__ = ('client', 'manager', '_transport', 'region', 'name', 'stats') + + def __init__(self, manager, host: str, port: int, password: str, region: str, name: str = None, + ssl: bool = False, session_id: Optional[str] = None): + self.client: 'Client' = manager.client + self.manager: 'NodeManager' = manager + self._transport = Transport(self, host, port, password, ssl, session_id) + self.region: str = region - self.name: str = name or '{}-{}:{}'.format(self.region, self.host, self.port) - self.filters: bool = filters + self.name: str = name or f'{region}-{host}:{port}' self.stats: Stats = Stats.empty(self) @property def available(self) -> bool: - """ Returns whether the node is available for requests. """ - return self._ws.connected + """ + Returns whether the node is available for requests. + + .. deprecated:: 5.0.0 + As of Lavalink server 4.0.0, a WebSocket connection is no longer required to operate a + node. As a result, this property is no longer considered useful. + """ + return True @property def _original_players(self) -> List[BasePlayer]: @@ -88,7 +87,7 @@ def _original_players(self) -> List[BasePlayer]: ------- List[:class:`BasePlayer`] """ - return [p for p in self._lavalink.player_manager.values() if p._original_node == self] + return [p for p in self.client.player_manager.values() if p._original_node == self] @property def players(self) -> List[BasePlayer]: @@ -99,29 +98,25 @@ def players(self) -> List[BasePlayer]: ------- List[:class:`BasePlayer`] """ - return [p for p in self._lavalink.player_manager.values() if p.node == self] + return [p for p in self.client.player_manager.values() if p.node == self] @property - def penalty(self) -> int: + def penalty(self) -> float: """ Returns the load-balancing penalty for this node. """ if not self.available or not self.stats: return 9e30 return self.stats.penalty.total - @property - def http_uri(self) -> str: - """ Returns a 'base' URI pointing to the node's address and port, also factoring in SSL. """ - return '{}://{}:{}'.format('https' if self.ssl else 'http', self.host, self.port) - async def destroy(self): """|coro| - Closes the WebSocket connection for this node. No further connection attempts will be made. + Destroys the transport and any underlying connections for this node. + This will also cleanly close the websocket. """ - await self._ws.destroy() + await self._transport.destroy() - async def get_tracks(self, query: str, check_local: bool = False) -> LoadResult: + async def get_tracks(self, query: str) -> LoadResult: """|coro| Retrieves a list of results pertaining to the provided query. @@ -130,27 +125,58 @@ async def get_tracks(self, query: str, check_local: bool = False) -> LoadResult: ---------- query: :class:`str` The query to perform a search for. - check_local: :class:`bool` - Whether to also search the query on sources registered with this Lavalink client. Returns ------- :class:`LoadResult` """ - return await self._lavalink.get_tracks(query, self, check_local) + return await self._transport._request('GET', 'loadtracks', params={'identifier': query}, to=LoadResult) + + async def decode_track(self, track: str) -> AudioTrack: + """|coro| + + Decodes a base64-encoded track string into an :class:`AudioTrack` object. + + Parameters + ---------- + track: :class:`str` + The base64-encoded track string to decode. + + Returns + ------- + :class:`AudioTrack` + """ + return await self._transport._request('GET', 'decodetrack', params={'track': track}, to=AudioTrack) + + async def decode_tracks(self, tracks: List[str]) -> List[AudioTrack]: + """|coro| + + Decodes a list of base64-encoded track strings into a list of :class:`AudioTrack`. + + Parameters + ---------- + tracks: List[:class:`str`] + A list of base64-encoded ``track`` strings. + + Returns + ------- + List[:class:`AudioTrack`] + A list of decoded AudioTracks. + """ + response = await self._transport._request('POST', 'decodetracks', json=tracks) + return list(map(AudioTrack, response)) - async def routeplanner_status(self): + async def get_routeplanner_status(self) -> Dict[str, Any]: """|coro| Retrieves the status of the target node's routeplanner. Returns ------- - :class:`dict` + Dict[str, Any] A dict representing the routeplanner information. """ - return await self._lavalink._get_request('{}/routeplanner/status'.format(self.http_uri), - headers={'Authorization': self.password}) + return await self._transport._request('GET', 'routeplanner/status') async def routeplanner_free_address(self, address: str) -> bool: """|coro| @@ -167,9 +193,10 @@ async def routeplanner_free_address(self, address: str) -> bool: :class:`bool` True if the address was freed, False otherwise. """ - return await self._lavalink._post_request('{}/routeplanner/free/address'.format(self.http_uri), - json={'address': address}, - headers={'Authorization': self.password}) + try: + return await self._transport._request('POST', 'routeplanner/free/address', json={'address': address}) + except RequestError: + return False async def routeplanner_free_all_failing(self) -> bool: """|coro| @@ -181,46 +208,336 @@ async def routeplanner_free_all_failing(self) -> bool: :class:`bool` True if all failing addresses were freed, False otherwise. """ - return await self._lavalink._post_request('{}/routeplanner/free/all'.format(self.http_uri), - headers={'Authorization': self.password}) + try: + return await self._transport._request('POST', 'routeplanner/free/all') + except RequestError: + return False + + async def get_info(self) -> Dict[str, Any]: + """|coro| + + Retrieves information about this node. + + Returns + ------- + Dict[str, Any] + A raw response containing information about the node. + """ + return await self._transport._request('GET', 'info') - async def get_plugins(self) -> List[Plugin]: + async def get_stats(self) -> Dict[str, Any]: """|coro| - Retrieves a list of plugins active on this node. + Retrieves statistics about this node. Returns ------- - List[:class:`Plugin`] - A list of active plugins. + Dict[str, Any] + A raw response containing information about the node. """ - data = await self._lavalink._get_request('{}/plugins'.format(self.http_uri), - headers={'Authorization': self.password}) - return [Plugin(plugin) for plugin in data] + return await self._transport._request('GET', 'stats') - async def _dispatch_event(self, event: Event): + async def get_version(self) -> str: """|coro| - Dispatches the given event to all registered hooks. + Retrieves the version of this node. + + Returns + ------- + str + The version of this Lavalink server. + """ + return await self._transport._request('GET', 'version', to=str, versioned=False) + + async def get_player(self, guild_id: Union[str, int]) -> Dict[str, Any]: + """|coro| + + Retrieves a player from the node. + This returns raw data, to retrieve a player you can interact with, use :meth:`PlayerManager.get`. + + Returns + ------- + Dict[str, Any] + A raw player object. + """ + session_id = self._transport.session_id + + if not session_id: + raise ClientError('Cannot retrieve a player without a valid session ID!') + + return await self._transport._request('GET', f'sessions/{session_id}/players/{guild_id}') + + async def get_players(self) -> List[Dict[str, Any]]: + """|coro| + + Retrieves a list of players from the node. + This returns raw data, to retrieve players you can interact with, use :attr:`players`. + + Returns + ------- + List[Dict[str, Any]] + A list of raw player objects. + """ + session_id = self._transport.session_id + + if not session_id: + raise ClientError('Cannot retrieve a list of players without a valid session ID!') + + return await self._transport._request('GET', f'sessions/{session_id}/players') + + @overload + async def update_player(self, + guild_id: Union[str, int], + encoded_track: Optional[str] = ..., + no_replace: bool = ..., + position: int = ..., + end_time: int = ..., + volume: int = ..., + paused: bool = ..., + filters: Optional[List[Filter]] = ..., + voice_state: Dict[str, Any] = ..., + user_data: Optional[Dict[str, Any]] = ..., + **kwargs) -> Dict[str, Any]: + ... + + @overload + async def update_player(self, + guild_id: Union[str, int], + identifier: str = ..., + no_replace: bool = ..., + position: int = ..., + end_time: int = ..., + volume: int = ..., + paused: bool = ..., + filters: Optional[List[Filter]] = ..., + voice_state: Dict[str, Any] = ..., + user_data: Dict[str, Any] = ..., + **kwargs) -> Dict[str, Any]: + ... + + @overload + async def update_player(self, + guild_id: Union[str, int], + no_replace: bool = ..., + position: int = ..., + end_time: int = ..., + volume: int = ..., + paused: bool = ..., + filters: Optional[List[Filter]] = ..., + voice_state: Dict[str, Any] = ..., + user_data: Dict[str, Any] = ..., + **kwargs) -> Dict[str, Any]: + ... + + async def update_player(self, # pylint: disable=too-many-locals + guild_id: Union[str, int], + encoded_track: Optional[str] = MISSING, + identifier: str = MISSING, + no_replace: bool = MISSING, + position: int = MISSING, + end_time: int = MISSING, + volume: int = MISSING, + paused: bool = MISSING, + filters: Optional[List[Filter]] = MISSING, + voice_state: Dict[str, Any] = MISSING, + user_data: Dict[str, Any] = MISSING, + **kwargs) -> Dict[str, Any]: + """|coro| + + .. _response object: https://lavalink.dev/api/rest#Player + + Update the state of a player. + + Warning + ------- + If this function is called directly, rather than through, e.g. a player, + the internal state is not guaranteed! This means that any attributes accessible through other classes + may not correspond with those stored in, or provided by the server. Use with caution! Parameters ---------- - event: :class:`Event` - The event to dispatch to the hooks. + guild_id: Union[str, int] + The guild ID of the player to update. + encoded_track: Optional[str] + The base64-encoded track string to play. + You may provide ``None`` to stop the player. + + Warning + ------- + This option is mutually exclusive with ``identifier``. You cannot provide both options. + identifier: str + The identifier of the track to play. This can be a track ID or URL. It may not be a + search query or playlist link. If it yields a search, playlist, or no track, a :class:`RequestError` + will be raised. + + Warning + ------- + This option is mutually exclusive with ``encoded_track``. You cannot provide both options. + no_replace: bool + Whether to replace the currently playing track (if one exists) with the new track. + Only takes effect if ``identifier`` or ``encoded_track`` is provided. + This parameter will only take effect when a track is provided. + position: int + The track position in milliseconds. This can be used to seek. + end_time: int + The position, in milliseconds, to end the track at. + volume: int + The new volume of the player. This must be within the range of 0 to 1000. + paused: bool + Whether to pause the player. + filters: Optional[List[:class:`Filter`]] + The filters to apply to the player. + Specify ``None`` or ``[]`` to clear. + voice_state: Dict[str, Any] + The new voice state of the player. + user_data: Dict[str, Any] + The user data to attach to the track, if one is provided. + This parameter will only take effect when a track is provided. + **kwargs: Any + The kwargs to use when updating the player. You can specify any extra parameters that may be + used by plugins, which offer extra features not supported out-of-the-box by Lavalink.py. + + Returns + ------- + Dict[str, Any] + The raw player update `response object`_. """ - await self._lavalink._dispatch_event(event) + session_id = self._transport.session_id + + if not session_id: + raise ClientError('Cannot update the state of a player without a valid session ID!') + + if encoded_track is not MISSING and identifier is not MISSING: + raise ValueError('encoded_track and identifier are mutually exclusive options, you may not specify both together.') + + params = {} + json = kwargs + + if identifier is not MISSING or encoded_track is not MISSING: + track = {} + + if identifier is not MISSING: + track['identifier'] = identifier + elif encoded_track is not MISSING: + track['encoded'] = encoded_track + + if user_data is not MISSING: + track['userData'] = user_data + + if no_replace is not MISSING: + params['noReplace'] = str(no_replace).lower() + + json['track'] = track + + if position is not MISSING: + if not isinstance(position, (int, float)): + raise ValueError('position must be an int!') + + json['position'] = position + + if end_time is not MISSING: + if not isinstance(end_time, int) or end_time <= 0: + raise ValueError('end_time must be an int, and greater than 0!') + + json['endTime'] = end_time + + if volume is not MISSING: + if not isinstance(volume, int) or not 0 <= volume <= 1000: + raise ValueError('volume must be an int, and within the range of 0 to 1000!') + + json['volume'] = volume + + if paused is not MISSING: + if not isinstance(paused, bool): + raise ValueError('paused must be a bool!') + + json['paused'] = paused + + if filters is not MISSING: + if filters is not None: + if not isinstance(filters, list) or not all(isinstance(f, Filter) for f in filters): + raise ValueError('filters must be a list of Filter!') + + serialized = defaultdict(dict) - async def _send(self, **data): + for filter_ in filters: + filter_obj = serialized['pluginFilters'] if filter_.plugin_filter else serialized + filter_obj.update(filter_.serialize()) + + json['filters'] = serialized + else: + json['filters'] = {} + + if voice_state is not MISSING: + if not isinstance(voice_state, dict): + raise ValueError('voice_state must be a dict!') + + json['voice'] = voice_state + + if not json: + return + + return await self._transport._request('PATCH', f'sessions/{session_id}/players/{guild_id}', + params=params, json=json) + + async def destroy_player(self, guild_id: Union[str, int]) -> bool: + """|coro| + + Destroys a player on the node. + It's recommended that you use :meth:`PlayerManager.destroy` to destroy a player. + + Returns + ------- + bool + Whether the player was destroyed. + """ + session_id = self._transport.session_id + + if not session_id: + raise ClientError('Cannot destroy a player without a valid session ID!') + + return await self._transport._request('DELETE', f'sessions/{session_id}/players/{guild_id}') + + async def update_session(self, resuming: bool = MISSING, timeout: int = MISSING) -> Dict[str, Any]: """|coro| - Sends the passed data to the node via the websocket connection. + Update the session for this node. Parameters ---------- - data: class:`any` - The dict to send to Lavalink. + resuming: bool + Whether to enable resuming for this session. + timeout: int + How long the node will wait for the session to resume before destroying it, in seconds. + + Returns + ------- + Dict[str, Any] + A raw response from the node containing the current session configuration. """ - await self._ws._send(**data) + session_id = self._transport.session_id + + if not session_id: + raise ClientError('Cannot update a session without a valid session ID!') + + json = {} + + if resuming is not MISSING: + if not isinstance(resuming, bool): + raise ValueError('resuming must be a bool!') + + json['resuming'] = resuming + + if timeout is not MISSING: + if not isinstance(timeout, int) or 0 >= timeout: + raise ValueError('timeout must be an int greater than 0!') + + json['timeout'] = timeout + + if not json: + return + + return await self._transport._request('PATCH', f'sessions/{session_id}', json=json) def __repr__(self): - return ''.format(self) + return f'' diff --git a/lavalink/nodemanager.py b/lavalink/nodemanager.py index 1b1ad080..17b8a59f 100644 --- a/lavalink/nodemanager.py +++ b/lavalink/nodemanager.py @@ -22,11 +22,14 @@ SOFTWARE. """ import logging -from typing import List, Optional +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple -from .events import NodeConnectedEvent, NodeDisconnectedEvent +from .errors import ClientError from .node import Node +if TYPE_CHECKING: + from .client import Client + _log = logging.getLogger(__name__) DEFAULT_REGIONS = { 'asia': ('hongkong', 'singapore', 'sydney', 'japan', 'southafrica', 'india'), @@ -45,32 +48,43 @@ class NodeManager: Attributes ---------- - nodes: :class:`list` + client: :class:`Client` + The Lavalink client. + nodes: List[:class:`Node`] Cache of all the nodes that Lavalink has created. - regions: :class:`dict` + regions: Dict[str, Tuple[str]] A mapping of continent -> Discord RTC regions. """ - def __init__(self, lavalink, regions: dict): - self._lavalink = lavalink + __slots__ = ('_player_queue', '_connect_back', 'client', 'nodes', 'regions') + + def __init__(self, client, regions: Dict[str, Tuple[str]], connect_back: bool): self._player_queue = [] - self.nodes = [] - self.regions = regions or DEFAULT_REGIONS + self._connect_back: bool = connect_back + self.client: 'Client' = client + self.nodes: List[Node] = [] + self.regions: Dict[str, Tuple[str]] = regions or DEFAULT_REGIONS - def __len__(self): + def __len__(self) -> int: return len(self.nodes) - def __iter__(self): - for n in self.nodes: - yield n + def __iter__(self) -> Iterator[Node]: + for node in self.nodes: + yield node @property def available_nodes(self) -> List[Node]: - """ Returns a list of available nodes. """ + """ + Returns a list of available nodes. + + .. deprecated:: 5.0.0 + As of Lavalink server 4.0.0, a WebSocket connection is no longer required to operate a + node. As a result, this property is no longer considered useful as all nodes are considered + available. + """ return [n for n in self.nodes if n.available] - def add_node(self, host: str, port: int, password: str, region: str, - resume_key: str = None, resume_timeout: int = 60, name: str = None, - reconnect_attempts: int = 3, filters: bool = False, ssl: bool = False): + def add_node(self, host: str, port: int, password: str, region: str, name: str = None, + ssl: bool = False, session_id: Optional[str] = None) -> Node: """ Adds a node to Lavalink's node manager. @@ -84,37 +98,33 @@ def add_node(self, host: str, port: int, password: str, region: str, The password used for authentication. region: :class:`str` The region to assign this node to. - resume_key: Optional[:class:`str`] - A resume key used for resuming a session upon re-establishing a WebSocket connection to Lavalink. - Defaults to ``None``. - resume_timeout: Optional[:class:`int`] - How long the node should wait for a connection while disconnected before clearing all players. - Defaults to ``60``. name: Optional[:class:`str`] An identifier for the node that will show in logs. Defaults to ``None``. reconnect_attempts: Optional[:class:`int`] The amount of times connection with the node will be reattempted before giving up. Set to `-1` for infinite. Defaults to ``3``. - filters: Optional[:class:`bool`] - Whether to use the new ``filters`` op. This setting currently only applies to development - Lavalink builds, where the ``equalizer`` op was swapped out for the broader ``filters`` op which - offers more than just equalizer functionality. Ideally, you should only change this setting if you - know what you're doing, as this can prevent the effects from working. ssl: Optional[:class:`bool`] Whether to use SSL for the node. SSL will use ``wss`` and ``https``, instead of ``ws`` and ``http``, respectively. Your node should support SSL if you intend to enable this, either via reverse proxy or other methods. Only enable this if you know what you're doing. + session_id: Optional[:class:`str`] + The ID of the session to resume. Defaults to ``None``. + Only specify this if you have the ID of the session you want to resume. + + Returns + ------- + :class:`Node` + The created Node instance. """ - node = Node(self, host, port, password, region, resume_key, resume_timeout, name, reconnect_attempts, filters, ssl) + node = Node(self, host, port, password, region, name, ssl, session_id) self.nodes.append(node) - - _log.info('Added node \'%s\'', node.name) + return node def remove_node(self, node: Node): """ Removes a node. - Make sure you have called :func:`Node.destroy` to close any open WebSocket connection. + Make sure you have called :func:`Node.destroy` to close any resources used by this Node. Parameters ---------- @@ -122,7 +132,6 @@ def remove_node(self, node: Node): The node to remove from the list. """ self.nodes.remove(node) - _log.info('Removed node \'%s\'', node.name) def get_nodes_by_region(self, region_key: str) -> List[Node]: """ @@ -171,7 +180,7 @@ def get_region(self, endpoint: str) -> str: return None - def find_ideal_node(self, region: str = None) -> Optional[Node]: + def find_ideal_node(self, region: str = None, exclude: Optional[List[Node]] = None) -> Optional[Node]: """ Finds the best (least used) node in the given region, if applicable. @@ -179,17 +188,21 @@ def find_ideal_node(self, region: str = None) -> Optional[Node]: ---------- region: Optional[:class:`str`] The region to find a node in. Defaults to ``None``. + exclude: Optional[List[:class:`Node`]] + A list of nodes to exclude from the choice. Returns ------- Optional[:class:`Node`] """ + exclusions = exclude or [] nodes = None + if region: - nodes = [n for n in self.available_nodes if n.region == region] + nodes = [n for n in self.available_nodes if n.region == region and n not in exclusions] if not nodes: # If there are no regional nodes available, or a region wasn't specified. - nodes = self.available_nodes + nodes = [n for n in self.available_nodes if n not in exclusions] if not nodes: return None @@ -197,40 +210,28 @@ def find_ideal_node(self, region: str = None) -> Optional[Node]: best_node = min(nodes, key=lambda node: node.penalty) return best_node - async def _node_connect(self, node: Node): - """ - Called when a node is connected from Lavalink. - - Parameters - ---------- - node: :class:`Node` - The node that has just connected. - """ + async def _handle_node_ready(self, node: Node): for player in self._player_queue: await player.change_node(node) original_node_name = player._original_node.name if player._original_node else '[no node]' _log.debug('Moved player %d from node \'%s\' to node \'%s\'', player.guild_id, original_node_name, node.name) - if self._lavalink._connect_back: + if self._connect_back: for player in node._original_players: await player.change_node(node) player._original_node = None self._player_queue.clear() - await self._lavalink._dispatch_event(NodeConnectedEvent(node)) - async def _node_disconnect(self, node: Node, code: int, reason: str): - """ + async def _handle_node_disconnect(self, node: Node): + """|coro| + Called when a node is disconnected from Lavalink. Parameters ---------- node: :class:`Node` The node that has just connected. - code: :class:`int` - The code for why the node was disconnected. - reason: :class:`str` - The reason why the node was disconnected. """ for player in node.players: try: @@ -238,17 +239,20 @@ async def _node_disconnect(self, node: Node, code: int, reason: str): except: # noqa: E722 pylint: disable=bare-except _log.exception('An error occurred whilst calling player.node_unavailable()') - await self._lavalink._dispatch_event(NodeDisconnectedEvent(node, code, reason)) - - best_node = self.find_ideal_node(node.region) + best_node = self.find_ideal_node(node.region, exclude=[node]) # Don't use the node these players are moving from. if not best_node: self._player_queue.extend(node.players) - _log.error('Unable to move players, no available nodes! Waiting for a node to become available.') + _log.warning('Unable to move players, no available nodes! Waiting for a node to become available.') return + # TODO: This may need reinvestigating to make it more robust with the lack of WS requirement. + # i.e. we need a way to determine whether nodes are "reachable". for player in node.players: - await player.change_node(best_node) + try: + await player.change_node(best_node) - if self._lavalink._connect_back: - player._original_node = node + if self._connect_back: + player._original_node = node + except ClientError: + _log.error('Failed to move player %d from node \'%s\' to new node \'%s\'', player.guild_id, node.name, best_node.name) diff --git a/lavalink/player.py b/lavalink/player.py new file mode 100644 index 00000000..f0574200 --- /dev/null +++ b/lavalink/player.py @@ -0,0 +1,652 @@ +""" +MIT License + +Copyright (c) 2017-present Devoxin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import logging +from random import randrange +from time import time +from typing import (TYPE_CHECKING, Dict, List, Optional, Type, # Literal + TypeVar, Union) + +from .abc import BasePlayer, DeferredAudioTrack +from .common import MISSING +from .errors import PlayerErrorEvent, RequestError +from .events import (NodeChangedEvent, QueueEndEvent, TrackEndEvent, + TrackStuckEvent) +from .filters import Filter +from .server import AudioTrack + +if TYPE_CHECKING: + from .node import Node + +_log = logging.getLogger(__name__) + +FilterT = TypeVar('FilterT', bound=Filter) + + +class DefaultPlayer(BasePlayer): + """ + The player that Lavalink.py uses by default. + + This should be sufficient for most use-cases. + + Attributes + ---------- + LOOP_NONE: :class:`int` + Class attribute. Disables looping entirely. + LOOP_SINGLE: :class:`int` + Class attribute. Enables looping for a single (usually currently playing) track only. + LOOP_QUEUE: :class:`int` + Class attribute. Enables looping for the entire queue. When a track finishes playing, it'll be added to the end of the queue. + + guild_id: :class:`int` + The guild id of the player. + node: :class:`Node` + The node that the player is connected to. + paused: :class:`bool` + Whether or not a player is paused. + position_timestamp: :class:`int` + Returns the track's elapsed playback time as an epoch timestamp. + volume: :class:`int` + The volume at which the player is playing at. + shuffle: :class:`bool` + Whether or not to mix the queue up in a random playing order. + loop: Literal[0, 1, 2] + Whether loop is enabled, and the type of looping. + This is an integer as loop supports multiple states. + + 0 = Loop off. + + 1 = Loop track. + + 2 = Loop queue. + + Example + ------- + .. code:: python + + if player.loop == player.LOOP_NONE: + await ctx.send('Not looping.') + elif player.loop == player.LOOP_SINGLE: + await ctx.send(f'{player.current.title} is looping.') + elif player.loop == player.LOOP_QUEUE: + await ctx.send('This queue never ends!') + filters: Dict[:class:`str`, :class:`Filter`] + A mapping of str to :class:`Filter`, representing currently active filters. + queue: List[:class:`AudioTrack`] + A list of AudioTracks to play. + current: Optional[:class:`AudioTrack`] + The track that is playing currently, if any. + """ + LOOP_NONE: int = 0 + LOOP_SINGLE: int = 1 + LOOP_QUEUE: int = 2 + + def __init__(self, guild_id: int, node: 'Node'): + super().__init__(guild_id, node) + + self._user_data = {} + + self.paused: bool = False + self._internal_pause: bool = False # Toggled when player's node becomes unavailable, primarily used for track position tracking. + self._last_update: int = 0 + self._last_position: int = 0 + self.position_timestamp: int = 0 + self.volume: int = 100 + self.shuffle: bool = False + self.loop: int = 0 # 0 = off, 1 = single track, 2 = queue + self.filters: Dict[str, Filter] = {} + self.queue: List[AudioTrack] = [] + + @property + def is_playing(self) -> bool: + """ Returns the player's track state. """ + return self.is_connected and self.current is not None + + @property + def is_connected(self) -> bool: + """ Returns whether the player is connected to a voicechannel or not. """ + return self.channel_id is not None + + @property + def position(self) -> int: + """ Returns the track's elapsed playback time in milliseconds, adjusted for Lavalink stat interval. """ + if not self.is_playing: + return 0 + + if self.paused or self._internal_pause: + return min(self._last_position, self.current.duration) + + difference = int(time() * 1000) - self._last_update + return min(self._last_position + difference, self.current.duration) + + def store(self, key: object, value: object): + """ + Stores custom user data. + + Parameters + ---------- + key: :class:`object` + The key of the object to store. + value: :class:`object` + The object to associate with the key. + """ + self._user_data.update({key: value}) + + def fetch(self, key: object, default=None): + """ + Retrieves the related value from the stored user data. + + Parameters + ---------- + key: :class:`object` + The key to fetch. + default: Optional[:class:`any`] + The object that should be returned if the key doesn't exist. Defaults to ``None``. + + Returns + ------- + Optional[:class:`any`] + """ + return self._user_data.get(key, default) + + def delete(self, key: object): + """ + Removes an item from the the stored user data. + + Parameters + ---------- + key: :class:`object` + The key to delete. + + Raises + ------ + :class:`KeyError` + If the key doesn't exist. + """ + try: + del self._user_data[key] + except KeyError: + pass + + def add(self, track: Union[AudioTrack, 'DeferredAudioTrack', Dict[str, Union[Optional[str], bool, int]]], + requester: int = 0, index: int = None): + """ + Adds a track to the queue. + + Parameters + ---------- + track: Union[:class:`AudioTrack`, :class:`DeferredAudioTrack`, Dict[str, Union[Optional[str], bool, int]]] + The track to add. Accepts either an AudioTrack or + a dict representing a track returned from Lavalink. + requester: :class:`int` + The ID of the user who requested the track. + index: Optional[:class:`int`] + The index at which to add the track. + If index is left unspecified, the default behaviour is to append the track. Defaults to ``None``. + """ + track = AudioTrack(track, requester) if isinstance(track, dict) else track + + if requester != 0: + track.requester = requester + + if index is None: + self.queue.append(track) + else: + self.queue.insert(index, track) + + async def play(self, + track: Optional[Union[AudioTrack, 'DeferredAudioTrack', Dict[str, Union[Optional[str], bool, int]]]] = None, + start_time: int = 0, + end_time: int = MISSING, + no_replace: bool = MISSING, + volume: int = MISSING, + pause: bool = False, + **kwargs): + """|coro| + + Plays the given track. + + This method differs from :meth:`BasePlayer.play_track` in that it contains additional logic + to handle certain attributes, such as ``loop``, ``shuffle``, and loading a base64 string from :class:`DeferredAudioTrack`. + + :meth:`BasePlayer.play_track` is a no-frills, raw function which will unconditionally tell the node to play exactly whatever + it is passed. + + Parameters + ---------- + track: Optional[Union[:class:`DeferredAudioTrack`, :class:`AudioTrack`, Dict[str, Union[Optional[str], bool, int]]]] + The track to play. If left unspecified, this will default to the first track in the queue. Defaults to ``None`` + which plays the next song in queue. Accepts either an AudioTrack or a dict representing a track + returned from Lavalink. + start_time: :class:`int` + The number of milliseconds to offset the track by. + If left unspecified, the track will start from the beginning. + end_time: :class:`int` + The position at which the track should stop playing. + This is an absolute position, so if you want the track to stop at 1 minute, you would pass 60000. + If left unspecified, the track will play through to the end. + no_replace: :class:`bool` + If set to true, operation will be ignored if the player already has a current track. + If left unspecified, the currently playing track will always be replaced. + volume: :class:`int` + The initial volume to set. This is useful for changing the volume between tracks etc. + If left unspecified, the volume will remain at its current setting. + pause: :class:`bool` + Whether to immediately pause the track after loading it. Defaults to ``False``. + **kwargs: Any + The kwargs to use when playing. You can specify any extra parameters that may be + used by plugins, which offer extra features not supported out-of-the-box by Lavalink.py. + + Raises + ------ + :class:`ValueError` + If invalid values were provided for ``start_time`` or ``end_time``. + :class:`TypeError` + If wrong types were provided for ``no_replace``, ``volume`` or ``pause``. + """ + if isinstance(no_replace, bool) and no_replace and self.is_playing: + return + + if track is not None and isinstance(track, dict): + track = AudioTrack(track, 0) + + if self.loop > 0 and self.current: + if self.loop == 1: + if track is not None: + self.queue.insert(0, self.current) + else: + track = self.current + elif self.loop == 2: + self.queue.append(self.current) + + self._last_position = 0 + self.position_timestamp = 0 + self.paused = pause + + if not track: + if not self.queue: + await self.stop() # Also sets current to None. + await self.client._dispatch_event(QueueEndEvent(self)) + return + + pop_at = randrange(len(self.queue)) if self.shuffle else 0 + track = self.queue.pop(pop_at) + + if start_time is not MISSING: + if not isinstance(start_time, int) or not 0 <= start_time < track.duration: + raise ValueError('start_time must be an int with a value equal to, or greater than 0, and less than the track duration') + + if end_time is not MISSING: + if not isinstance(end_time, int) or not 1 <= end_time <= track.duration: + raise ValueError('end_time must be an int with a value equal to, or greater than 1, and less than, or equal to the track duration') + + await self.play_track(track, start_time, end_time, no_replace, volume, pause, **kwargs) + + async def stop(self): + """|coro| + + Stops the player. + """ + await self.node.update_player(self._internal_id, encoded_track=None) + self.current = None + + async def skip(self): + """|coro| + + Plays the next track in the queue, if any. + """ + await self.play() + + def set_loop(self, loop: int): + """ + Sets whether the player loops between a single track, queue or none. + + 0 = off, 1 = single track, 2 = queue. + + Parameters + ---------- + loop: Literal[0, 1, 2] + The loop setting. 0 = off, 1 = single track, 2 = queue. + """ + if not 0 <= loop <= 2: + raise ValueError('Loop must be 0, 1 or 2.') + + self.loop = loop + + def set_shuffle(self, shuffle: bool): + """ + Sets the player's shuffle state. + + Parameters + ---------- + shuffle: :class:`bool` + Whether to shuffle the player or not. + """ + self.shuffle = shuffle + + async def set_pause(self, pause: bool): + """|coro| + + Sets the player's paused state. + + Parameters + ---------- + pause: :class:`bool` + Whether to pause the player or not. + """ + await self.node.update_player(self._internal_id, paused=pause) + self.paused = pause + + async def set_volume(self, vol: int): + """|coro| + + Sets the player's volume + + Note + ---- + A limit of 1000 is imposed by Lavalink. + + Parameters + ---------- + vol: :class:`int` + The new volume level. + """ + vol = max(min(vol, 1000), 0) + await self.node.update_player(self._internal_id, volume=vol) + self.volume = vol + + async def seek(self, position: int): + """|coro| + + Seeks to a given position in the track. + + Parameters + ---------- + position: :class:`int` + The new position to seek to in milliseconds. + """ + if not isinstance(position, int): + raise ValueError('position must be an int!') + + await self.node.update_player(self._internal_id, position=position) + + async def set_filters(self, *filters: FilterT): + """|coro| + + This sets multiple filters at once. + + Applies the corresponding filters within Lavalink. + This will overwrite any identical filters that are already applied. + + Parameters + ---------- + *filters: :class:`Filter` + The filters to apply. + + Raises + ------ + :class:`TypeError` + If any of the provided filters is not of type :class:`Filter`. + """ + for _filter in filters: + if not isinstance(_filter, Filter): + raise TypeError(f'Expected object of type Filter, not {type(_filter).__name__}') + + filter_name = type(_filter).__name__.lower() + self.filters[filter_name] = _filter + + await self._apply_filters() + + async def set_filter(self, _filter: FilterT): + """|coro| + + Applies the corresponding filter within Lavalink. + This will overwrite the filter if it's already applied. + + Example + ------- + .. code:: python + + equalizer = Equalizer() + equalizer.update(bands=[(0, 0.2), (1, 0.3), (2, 0.17)]) + player.set_filter(equalizer) + + Parameters + ---------- + _filter: :class:`Filter` + The filter instance to set. + + Raises + ------ + :class:`TypeError` + If the provided ``_filter`` is not of type :class:`Filter`. + """ + if not isinstance(_filter, Filter): + raise TypeError(f'Expected object of type Filter, not {type(_filter).__name__}') + + filter_name = type(_filter).__name__.lower() + self.filters[filter_name] = _filter + await self._apply_filters() + + async def update_filter(self, _filter: Type[FilterT], **kwargs): + """|coro| + + Updates a filter using the upsert method; + if the filter exists within the player, its values will be updated; + if the filter does not exist, it will be created with the provided values. + + This will not overwrite any values that have not been provided. + + Example + ------- + .. code :: python + + player.update_filter(Timescale, speed=1.5) + # This means that, if the Timescale filter is already applied + # and it already has set values of "speed=1, pitch=1.2", pitch will remain + # the same, however speed will be changed to 1.5 so the result is + # "speed=1.5, pitch=1.2" + + Parameters + ---------- + _filter: Type[:class:`Filter`] + The filter class (**not** an instance of, see above example) to upsert. + **kwargs: Any + The kwargs to pass to the filter. + + Raises + ------ + :class:`TypeError` + If the provided ``_filter`` is not of type :class:`Filter`. + """ + if isinstance(_filter, Filter): + raise TypeError(f'Expected class of type Filter, not an instance of {type(_filter).__name__}') + + if not issubclass(_filter, Filter): + raise TypeError(f'Expected subclass of type Filter, not {_filter.__name__}') + + filter_name = _filter.__name__.lower() + + filter_instance = self.filters.get(filter_name, _filter()) + filter_instance.update(**kwargs) + self.filters[filter_name] = filter_instance + await self._apply_filters() + + def get_filter(self, _filter: Union[Type[FilterT], str]): + """ + Returns the corresponding filter, if it's enabled. + + Example + ------- + .. code:: python + + from lavalink.filters import Timescale + timescale = player.get_filter(Timescale) + # or + timescale = player.get_filter('timescale') + + Parameters + ---------- + _filter: Union[Type[:class:`Filter`], :class:`str`] + The filter name, or filter class (**not** an instance of, see above example), to get. + + Returns + ------- + Optional[:class:`Filter`] + """ + if isinstance(_filter, str): + filter_name = _filter + elif isinstance(_filter, Filter): # User passed an instance of. + filter_name = type(_filter).__name__ + else: + if not issubclass(_filter, Filter): + raise TypeError(f'Expected subclass of type Filter, not {_filter.__name__}') + + filter_name = _filter.__name__ + + return self.filters.get(filter_name.lower(), None) + + async def remove_filter(self, _filter: Union[Type[FilterT], str]): + """|coro| + + Removes a filter from the player, undoing any effects applied to the audio. + + Example + ------- + .. code:: python + + player.remove_filter(Timescale) + # or + player.remove_filter('timescale') + + Parameters + ---------- + _filter: Union[Type[:class:`Filter`], :class:`str`] + The filter name, or filter class (**not** an instance of, see above example), to remove. + """ + if isinstance(_filter, str): + filter_name = _filter + elif isinstance(_filter, Filter): # User passed an instance of. + filter_name = type(_filter).__name__ + else: + if not issubclass(_filter, Filter): + raise TypeError(f'Expected subclass of type Filter, not {_filter.__name__}') + + filter_name = _filter.__name__ + + fn_lowered = filter_name.lower() + + if fn_lowered in self.filters: + self.filters.pop(fn_lowered) + await self._apply_filters() + + async def clear_filters(self): + """|coro| + + Clears all currently-enabled filters. + """ + self.filters.clear() + await self._apply_filters() + + async def _apply_filters(self): + await self.node.update_player(self._internal_id, filters=list(self.filters.values())) + + async def _handle_event(self, event): + """ + Handles the given event as necessary. + + Parameters + ---------- + event: :class:`Event` + The event that will be handled. + """ + # A track throws loadFailed when it fails to provide any audio before throwing an exception. + # A TrackStuckEvent is not proceeded by a TrackEndEvent. In theory, you could ignore a TrackStuckEvent + # and hope that a track will eventually play, however, it's unlikely. + + if isinstance(event, TrackStuckEvent) or isinstance(event, TrackEndEvent) and event.reason.may_start_next(): + try: + await self.play() + except RequestError as error: + await self.client._dispatch_event(PlayerErrorEvent(self, error)) + _log.exception('[DefaultPlayer:%d] Encountered a request error whilst starting a new track.', self.guild_id) + + async def _update_state(self, state: dict): + """ + Updates the position of the player. + + Parameters + ---------- + state: :class:`dict` + The state that is given to update. + """ + self._last_update = int(time() * 1000) + self._last_position = state.get('position', 0) + self.position_timestamp = state.get('time', 0) + + async def node_unavailable(self): + """|coro| + + Called when a player's node becomes unavailable. + Useful for changing player state before it's moved to another node. + """ + self._internal_pause = True + + async def change_node(self, node: 'Node'): + """|coro| + + Changes the player's node + + Parameters + ---------- + node: :class:`Node` + The node the player is changed to. + """ + if self.node.available: + await self.node.destroy_player(self._internal_id) + + old_node = self.node + self.node = node + + if self._voice_state: + await self._dispatch_voice_update() + + if self.current: + playable_track = self.current.track + + if isinstance(self.current, DeferredAudioTrack) and playable_track is None: + playable_track = await self.current.load(self.client) + + await self.node.update_player(self._internal_id, encoded_track=playable_track, position=self.position, + paused=self.paused, volume=self.volume) + self._last_update = int(time() * 1000) + + self._internal_pause = False + + if self.filters: + await self._apply_filters() + + await self.client._dispatch_event(NodeChangedEvent(self, old_node, node)) + + def __repr__(self): + return f'' diff --git a/lavalink/playermanager.py b/lavalink/playermanager.py index c63c8bb0..8d0f4a71 100644 --- a/lavalink/playermanager.py +++ b/lavalink/playermanager.py @@ -22,14 +22,20 @@ SOFTWARE. """ import logging -from typing import Callable, Dict, Iterator +from typing import (TYPE_CHECKING, Callable, Dict, Iterator, Optional, Tuple, + Type, TypeVar) -from .errors import NodeError -from .models import BasePlayer +from .errors import ClientError from .node import Node +from .player import BasePlayer + +if TYPE_CHECKING: + from .client import Client _log = logging.getLogger(__name__) +PlayerT = TypeVar('PlayerT', bound=BasePlayer) + class PlayerManager: """ @@ -42,23 +48,25 @@ class PlayerManager: Attributes ---------- - players: :class:`dict` + client: :class:`Client` + The Lavalink client. + players: Dict[int, :class:`BasePlayer`] Cache of all the players that Lavalink has created. """ + __slots__ = ('client', '_player_cls', 'players') - def __init__(self, lavalink, player): + def __init__(self, client, player): if not issubclass(player, BasePlayer): - raise ValueError( - 'Player must implement BasePlayer or DefaultPlayer.') + raise ValueError('Player must implement BasePlayer.') - self._lavalink = lavalink + self.client: 'Client' = client self._player_cls = player self.players: Dict[int, BasePlayer] = {} - def __len__(self): + def __len__(self) -> int: return len(self.players) - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[int, BasePlayer]]: """ Returns an iterator that yields a tuple of (guild_id, player). """ for guild_id, player in self.players.items(): yield guild_id, player @@ -88,7 +96,7 @@ def find_all(self, predicate: Callable[[BasePlayer], bool] = None): return [p for p in self.players.values() if bool(predicate(p))] - def get(self, guild_id: int): + def get(self, guild_id: int) -> Optional[BasePlayer]: """ Gets a player from cache. @@ -118,7 +126,13 @@ def remove(self, guild_id: int): player = self.players.pop(guild_id) player.cleanup() - def create(self, guild_id: int, region: str = None, endpoint: str = None, node: Node = None): + def create(self, + guild_id: int, + *, + region: Optional[str] = None, + endpoint: Optional[str] = None, + node: Optional[Node] = None, + cls: Optional[Type[PlayerT]] = None) -> BasePlayer: """ Creates a player if one doesn't exist with the given information. @@ -137,11 +151,28 @@ def create(self, guild_id: int, region: str = None, endpoint: str = None, node: guild_id: :class:`int` The guild_id to associate with the player. region: Optional[:class:`str`] - The region to use when selecting a Lavalink node. Defaults to ``None``. + The region to use when selecting a Lavalink node. + Defaults to ``None``. endpoint: Optional[:class:`str`] - The address of the Discord voice server. Defaults to ``None``. + The address of the Discord voice server. + Defaults to ``None``. node: Optional[:class:`Node`] - The node to put the player on. Defaults to ``None`` and a node with the lowest penalty is chosen. + The node to put the player on. + Defaults to ``None``, which selects the node with the lowest penalty. + cls: Optional[Type[:class:`BasePlayer`]] + The player class to use when instantiating a new player. + Defaults to ``None`` which uses the player class provided to :class:`Client`. + If no class was provided, this will typically be :class:`DefaultPlayer`. + + Warning + ------- + This function could return a player of a different type to that specified in ``cls``, + if a player was created before with a different class type. + + Raises + ------ + :class:`ValueError` + If the provided ``cls`` is not a valid subclass of :class:`BasePlayer`. Returns ------- @@ -153,16 +184,21 @@ def create(self, guild_id: int, region: str = None, endpoint: str = None, node: if guild_id in self.players: return self.players[guild_id] + cls = cls or self._player_cls + + if not issubclass(cls, BasePlayer): + raise ValueError('Player must implement BasePlayer.') + if endpoint: # Prioritise endpoint over region parameter - region = self._lavalink.node_manager.get_region(endpoint) + region = self.client.node_manager.get_region(endpoint) - best_node = node or self._lavalink.node_manager.find_ideal_node(region) + best_node = node or self.client.node_manager.find_ideal_node(region) if not best_node: - raise NodeError('No available nodes!') + raise ClientError('No available nodes!') id_int = int(guild_id) - self.players[id_int] = player = self._player_cls(id_int, best_node) + self.players[id_int] = player = cls(id_int, best_node) _log.debug('Created player with GuildId %d on node \'%s\'', id_int, best_node.name) return player @@ -190,6 +226,6 @@ async def destroy(self, guild_id: int): player.cleanup() if player.node: - await player.node._send(op='destroy', guildId=player._internal_id) + await player.node.destroy_player(player._internal_id) _log.debug('Destroyed player with GuildId %d on node \'%s\'', guild_id, player.node.name if player.node else 'UNASSIGNED') diff --git a/lavalink/server.py b/lavalink/server.py new file mode 100644 index 00000000..2e90193a --- /dev/null +++ b/lavalink/server.py @@ -0,0 +1,367 @@ +""" +MIT License + +Copyright (c) 2017-present Devoxin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +This module serves to contain all entities which are deserialized using responses from +the Lavalink server. +""" +from enum import Enum as _Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from .errors import InvalidTrack + +if TYPE_CHECKING: + from .abc import DeferredAudioTrack + + +class Enum(_Enum): + def __eq__(self, other): + if self.__class__ is other.__class__: + return self.value == other.value + + if isinstance(other, str): + return self.value.lower() == other.lower() + + return False + + @classmethod + def from_str(cls, other: str): + try: + return cls[other.upper()] + except KeyError: + try: + return cls(other) + except ValueError as error: + raise ValueError(f'{other} is not a valid {cls.__name__} enum!') from error + + +class AudioTrack: + """ + .. _ISRC: https://en.wikipedia.org/wiki/International_Standard_Recording_Code + + Represents an AudioTrack. + + Parameters + ---------- + data: Union[Dict[str, Union[Optional[str], bool, int]], :class:`AudioTrack`] + The data to initialise an AudioTrack from. + requester: :class:`any` + The requester of the track. + extra: Dict[Any, Any] + Any extra information to store in this AudioTrack. + + Attributes + ---------- + track: Optional[:class:`str`] + The base64-encoded string representing a Lavalink-readable AudioTrack. + This is marked optional as it could be None when it's not set by a custom :class:`Source`, + which is expected behaviour when the subclass is a :class:`DeferredAudioTrack`. + identifier: :class:`str` + The track's id. For example, a youtube track's identifier will look like ``dQw4w9WgXcQ``. + is_seekable: :class:`bool` + Whether the track supports seeking. + author: :class:`str` + The track's uploader. + duration: :class:`int` + The duration of the track, in milliseconds. + stream: :class:`bool` + Whether the track is a live-stream. + title: :class:`str` + The title of the track. + uri: :class:`str` + The full URL of track. + artwork_url: Optional[:class:`str`] + A URL pointing to the track's artwork, if applicable. + isrc: Optional[:class:`str`] + The `ISRC`_ for the track, if applicable. + position: :class:`int` + The playback position of the track, in milliseconds. + This is a read-only property; setting it won't have any effect. + source_name: :class:`str` + The name of the source that this track was created by. + requester: :class:`int` + The ID of the user that requested this track. + plugin_info: Optional[Dict[str, Any]] + Addition track info provided by plugins. + user_data: Optional[Dict[str, Any]] + The user data that was attached to the track, if any. + extra: Dict[str, Any] + Any extra properties given to this AudioTrack will be stored here. + """ + __slots__ = ('raw', 'track', 'identifier', 'is_seekable', 'author', 'duration', 'stream', 'title', 'uri', + 'artwork_url', 'isrc', 'position', 'source_name', 'plugin_info', 'user_data', 'extra') + + def __init__(self, data: dict, requester: int = 0, **extra): + if isinstance(data, AudioTrack): + extra = {**data.extra, **extra} + data = data.raw + + self.raw: Dict[str, Union[Optional[str], bool, int]] = data + info = data.get('info', data) + + try: + self.track: Optional[str] = data.get('encoded') + self.identifier: str = info['identifier'] + self.is_seekable: bool = info['isSeekable'] + self.author: str = info['author'] + self.duration: int = info['length'] + self.stream: bool = info['isStream'] + self.title: str = info['title'] + self.uri: str = info['uri'] + self.artwork_url: Optional[str] = info.get('artworkUrl') + self.isrc: Optional[str] = info.get('isrc') + self.position: int = info.get('position', 0) + self.source_name: str = info.get('sourceName', 'unknown') + self.plugin_info: Optional[Dict[str, Any]] = data.get('pluginInfo') + self.user_data: Optional[Dict[str, Any]] = data.get('userData') + self.extra: Dict[str, Any] = {**extra, 'requester': requester} + except KeyError as error: + raise InvalidTrack(f'Cannot build a track from partial data! (Missing key: {error.args[0]})') from error + + def __getitem__(self, name): + if name == 'info': + return self + + return super().__getattribute__(name) + + @classmethod + def from_dict(cls, mapping: dict): + return cls(mapping) + + @property + def requester(self) -> int: + return self.extra['requester'] + + @requester.setter + def requester(self, requester): + self.extra['requester'] = requester + + def __repr__(self): + return f'' + + +class EndReason(Enum): + FINISHED = 'finished' + LOAD_FAILED = 'loadFailed' + STOPPED = 'stopped' + REPLACED = 'replaced' + CLEANUP = 'cleanup' + + def may_start_next(self) -> bool: + """ + Returns whether the next track may be started from this event. + + This is mostly used as a hint to determine whether the ``track_end_event`` should be + responsible for playing the next track. + + Returns + ------- + :class:`bool` + Whether the next track may be started. + """ + return self is EndReason.FINISHED or self is EndReason.LOAD_FAILED + + +class LoadType(Enum): + TRACK = 'TRACK' + PLAYLIST = 'PLAYLIST' + SEARCH = 'SEARCH' + EMPTY = 'EMPTY' + ERROR = 'ERROR' + + +class Severity(Enum): + COMMON = 'common' + SUSPICIOUS = 'suspicious' + FAULT = 'fault' + + +class PlaylistInfo: + """ + Attributes + ---------- + name: :class:`str` + The name of the playlist. + selected_track: :class:`int` + The index of the selected/highlighted track. + This will be -1 if there is no selected track. + """ + __slots__ = ('name', 'selected_track') + + def __init__(self, name: str, selected_track: int = -1): + self.name: str = name + self.selected_track: int = selected_track + + def __getitem__(self, k): # Exists only for compatibility, don't blame me + if k == 'selectedTrack': + k = 'selected_track' + return self.__getattribute__(k) + + @classmethod + def from_dict(cls, mapping: dict): + return cls(mapping.get('name'), mapping.get('selectedTrack', -1)) + + @classmethod + def none(cls): + return cls('', -1) + + def __repr__(self): + return f'' + + +class LoadResultError: + """ + Attributes + ---------- + message: :class:`str` + The error message. + severity: :enum:`Severity` + The severity of the error. + cause: :class:`str` + The cause of the error. + """ + __slots__ = ('message', 'severity', 'cause') + + def __init__(self, error: Dict[str, Any]): + self.message: str = error['message'] + self.severity: Severity = Severity.from_str(error['severity']) + self.cause: str = error['cause'] + + +class LoadResult: + """ + Attributes + ---------- + load_type: :class:`LoadType` + The load type of this result. + tracks: List[Union[:class:`AudioTrack`, :class:`DeferredAudioTrack`]] + The tracks in this result. + playlist_info: Optional[:class:`PlaylistInfo`] + The playlist metadata for this result. + The :class:`PlaylistInfo` could contain empty/false data if the :class:`LoadType` + is not :enum:`LoadType.PLAYLIST`. + plugin_info: Optional[Dict[:class:`str`, Any]] + Additional playlist info provided by plugins. + error: Optional[:class:`LoadResultError`] + The error associated with this ``LoadResult``. + This will be ``None`` if :attr:`load_type` is not :attr:`LoadType.ERROR`. + """ + __slots__ = ('load_type', 'playlist_info', 'tracks', 'plugin_info', 'error') + + def __init__(self, load_type: LoadType, tracks: List[Union[AudioTrack, 'DeferredAudioTrack']], + playlist_info: Optional[PlaylistInfo] = PlaylistInfo.none(), plugin_info: Optional[Dict[str, Any]] = None, + error: Optional[LoadResultError] = None): + self.load_type: LoadType = load_type + self.playlist_info: PlaylistInfo = playlist_info + self.tracks: List[Union[AudioTrack, 'DeferredAudioTrack']] = tracks + self.plugin_info: Optional[Dict[str, Any]] = plugin_info + self.error: Optional[LoadResultError] = error + + def __getitem__(self, k): # Exists only for compatibility, don't blame me + if k == 'loadType': + k = 'load_type' + elif k == 'playlistInfo': + k = 'playlist_info' + + return self.__getattribute__(k) + + @classmethod + def empty(cls): + return LoadResult(LoadType.EMPTY, []) + + @classmethod + def from_dict(cls, mapping: dict): + plugin_info: Optional[dict] = None + playlist_info: Optional[PlaylistInfo] = PlaylistInfo.none() + tracks: Optional[Union[AudioTrack, 'DeferredAudioTrack']] = [] + + data: Union[List[Dict[str, Any]], Dict[str, Any]] = mapping['data'] + load_type = LoadType.from_str(mapping['loadType']) + + if isinstance(data, dict): + plugin_info = data.get('pluginInfo') + + if load_type == LoadType.TRACK: + tracks = [AudioTrack(data, 0)] + elif load_type == LoadType.PLAYLIST: + playlist_info = PlaylistInfo.from_dict(data['info']) + tracks = [AudioTrack(track, 0) for track in data['tracks']] + elif load_type == LoadType.SEARCH: + tracks = [AudioTrack(track, 0) for track in data] + elif load_type == LoadType.ERROR: + error = LoadResultError(data) + return cls(load_type, [], playlist_info, plugin_info, error) + + return cls(load_type, tracks, playlist_info, plugin_info) + + @property + def selected_track(self) -> Optional[AudioTrack]: + """ + Convenience method for returning the selected track using + :attr:`PlaylistInfo.selected_track`. + + This could be ``None`` if :attr:`playlist_info` is ``None``, + or :attr:`PlaylistInfo.selected_track` is an invalid number. + + Returns + ------- + Optional[:class:`AudioTrack`] + """ + if self.playlist_info is not None: + index = self.playlist_info.selected_track + + if 0 <= index < len(self.tracks): + return self.tracks[index] + + return None + + def __repr__(self): + return f'' + + +class Plugin: + """ + Represents a Lavalink server plugin. + + Parameters + ---------- + data: Dict[str, Any] + The data to initialise a Plugin from. + + Attributes + ---------- + name: :class:`str` + The name of the plugin. + version: :class:`str` + The version of the plugin. + """ + __slots__ = ('name', 'version') + + def __init__(self, data: Dict[str, Any]): + self.name: str = data['name'] + self.version: str = data['version'] + + def __str__(self): + return f'{self.name} v{self.version}' + + def __repr__(self): + return f'' diff --git a/lavalink/stats.py b/lavalink/stats.py index cd709bee..b068757c 100644 --- a/lavalink/stats.py +++ b/lavalink/stats.py @@ -120,11 +120,11 @@ def __init__(self, node, data): self.system_load: float = cpu['systemLoad'] self.lavalink_load: float = cpu['lavalinkLoad'] - frame_stats = data.get('frameStats', {}) + frame_stats = data.get('frameStats') or {} self.frames_sent: int = frame_stats.get('sent', 0) self.frames_nulled: int = frame_stats.get('nulled', 0) self.frames_deficit: int = frame_stats.get('deficit', 0) - self.penalty = Penalty(self) + self.penalty: Penalty = Penalty(self) @classmethod def empty(cls, node): diff --git a/lavalink/transport.py b/lavalink/transport.py new file mode 100644 index 00000000..4be805e7 --- /dev/null +++ b/lavalink/transport.py @@ -0,0 +1,360 @@ +""" +MIT License + +Copyright (c) 2017-present Devoxin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import asyncio +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import aiohttp + +from .errors import AuthenticationError, ClientError, RequestError +from .events import (IncomingWebSocketMessage, NodeConnectedEvent, + NodeDisconnectedEvent, NodeReadyEvent, PlayerUpdateEvent, + TrackEndEvent, TrackExceptionEvent, TrackStartEvent, + TrackStuckEvent, WebSocketClosedEvent) +from .server import EndReason, Severity +from .stats import Stats + +if TYPE_CHECKING: + from .client import Client + from .node import Node + +_log = logging.getLogger(__name__) +CLOSE_TYPES = ( + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + aiohttp.WSMsgType.CLOSED +) +MESSAGE_QUEUE_MAX_SIZE = 25 +LAVALINK_API_VERSION = 'v4' + + +class Transport: + """ The class responsible for dealing with connections to Lavalink. """ + __slots__ = ('client', '_node', '_session', '_ws', '_message_queue', 'trace_requests', + '_host', '_port', '_password', '_ssl', 'session_id', '_destroyed') + + def __init__(self, node, host: str, port: int, password: str, ssl: bool, session_id: Optional[str]): + self.client: 'Client' = node.client + self._node: 'Node' = node + + self._session: aiohttp.ClientSession = self.client._session + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._message_queue = [] + self.trace_requests = False + + self._host: str = host + self._port: int = port + self._password: str = password + self._ssl: bool = ssl + + self.session_id: Optional[str] = session_id + self._destroyed: bool = False + + self.connect() + + @property + def ws_connected(self): + """ Returns whether the websocket is connected to Lavalink. """ + return self._ws is not None and not self._ws.closed + + @property + def http_uri(self) -> str: + """ Returns a 'base' URI pointing to the node's address and port, also factoring in SSL. """ + return f'{"https" if self._ssl else "http"}://{self._host}:{self._port}' + + async def close(self, code=aiohttp.WSCloseCode.OK): + """|coro| + + Shuts down the websocket connection if there is one. + """ + if self._ws: + await self._ws.close(code=code) + self._ws = None + + def connect(self) -> asyncio.Task: + """ Attempts to establish a connection to Lavalink. """ + loop = asyncio.get_event_loop() + return loop.create_task(self._connect()) + + async def destroy(self): + """|coro| + + Closes the WebSocket gracefully, and stops any further reconnecting. + Useful when needing to remove a node. + """ + self._destroyed = True + await self.close() + + async def _connect(self): + if self._destroyed: + raise IOError('Cannot instantiate any connections with a closed session!') + + if self._ws: + await self.close() + + headers = { + 'Authorization': self._password, + 'User-Id': str(self.client._user_id), + 'Client-Name': f'Lavalink.py/{__import__("lavalink").__version__}' + } + + if self.session_id is not None: + headers['Session-Id'] = self.session_id + + _log.info('[Node:%s] Establishing WebSocket connection to Lavalink...', self._node.name) + + protocol = 'wss' if self._ssl else 'ws' + attempt = 0 + + while not self.ws_connected and not self._destroyed: + attempt += 1 + try: + self._ws = await self._session.ws_connect(f'{protocol}://{self._host}:{self._port}/{LAVALINK_API_VERSION}/websocket', + headers=headers, + heartbeat=60) + except (aiohttp.ClientConnectorError, aiohttp.WSServerHandshakeError, aiohttp.ServerDisconnectedError) as error: + if isinstance(error, aiohttp.ClientConnectorError): + _log.warning('[Node:%s] Invalid response received; this may indicate that ' + 'Lavalink is not running, or is running on a port different ' + 'to the one you provided to `add_node`.', self._node.name) + elif isinstance(error, aiohttp.WSServerHandshakeError): + if error.status in (401, 403): # Special handling for 401/403 (Unauthorized/Forbidden). + _log.warning('[Node:%s] Authentication failed while trying to establish a connection to the node.', + self._node.name) + # We shouldn't try to establish any more connections as correcting this particular error + # would require the cog to be reloaded (or the bot to be rebooted), so further attempts + # would be futile, and a waste of resources. + else: + _log.warning('[Node:%s] The remote server returned code %d, the expected code was 101. This usually ' + 'indicates that the remote server is a webserver and not Lavalink. Check your ports, ' + 'and try again.', self._node.name, error.status) + + return + else: + _log.exception('[Node:%s] An unknown error occurred whilst trying to establish a connection to Lavalink', self._node.name) + + backoff = min(10 * attempt, 60) + await asyncio.sleep(backoff) + else: + _log.info('[Node:%s] WebSocket connection established', self._node.name) + await self.client._dispatch_event(NodeConnectedEvent(self._node)) + + if self._message_queue: + for message in self._message_queue: + await self._send(**message) + + self._message_queue.clear() + + attempt = 0 + await self._listen() + + async def _listen(self): + """ Listens for websocket messages. """ + close_code = None + close_reason = 'Improper websocket closure' + + async for msg in self._ws: + _log.debug('[Node:%s] Received WebSocket message: %s', self._node.name, msg.data) + + if msg.type == aiohttp.WSMsgType.TEXT: + await self._handle_message(msg.json()) + elif msg.type == aiohttp.WSMsgType.ERROR: + exc = self._ws.exception() + _log.error('[Node:%s] Exception in WebSocket!', self._node.name, exc_info=exc) + close_code = aiohttp.WSCloseCode.INTERNAL_ERROR + close_reason = 'WebSocket error' + break + elif msg.type in CLOSE_TYPES: + _log.debug('[Node:%s] Received close frame with code %d.', self._node.name, msg.data) + close_code = msg.data + close_reason = msg.extra + break + + close_code = close_code or self._ws.close_code + await self.close(close_code or aiohttp.WSCloseCode.OK) + await self._websocket_closed(close_code, close_reason) + + async def _websocket_closed(self, code: Optional[int] = None, reason: Optional[str] = None): + """ + Handles when the websocket is closed. + + Parameters + ---------- + code: Optional[:class:`int`] + The response code. + reason: Optional[:class:`str`] + Reason why the websocket was closed. Defaults to ``None``. + """ + _log.warning('[Node:%s] WebSocket disconnected with the following: code=%s reason=%s', self._node.name, code, reason) + self._ws = None + await self._node.manager._handle_node_disconnect(self._node) + await self.client._dispatch_event(NodeDisconnectedEvent(self._node, code, reason)) + + async def _handle_message(self, data: Union[Dict[Any, Any], List[Any]]): + """ + Handles the response from the websocket. + + Parameters + ---------- + data: Union[Dict[Any, Any], List[Any]] + The payload received from the Lavalink server. + """ + if self.client.has_listeners(IncomingWebSocketMessage): + await self.client._dispatch_event(IncomingWebSocketMessage(data.copy(), self._node)) + + if not isinstance(data, dict) or 'op' not in data: + return + + op = data['op'] # pylint: disable=C0103 + + if op == 'ready': + self.session_id = data['sessionId'] + await self._node.manager._handle_node_ready(self._node) + await self.client._dispatch_event(NodeReadyEvent(self, data['sessionId'], data['resumed'])) + elif op == 'playerUpdate': + guild_id = int(data['guildId']) + player = self.client.player_manager.get(guild_id) + + if not player: + _log.debug('[Node:%s] Received playerUpdate for non-existent player! GuildId: %d', self._node.name, guild_id) + return + + state = data['state'] + await player._update_state(state) + await self.client._dispatch_event(PlayerUpdateEvent(player, state)) + elif op == 'stats': + self._node.stats = Stats(self._node, data) + elif op == 'event': + await self._handle_event(data) + else: + _log.warning('[Node:%s] Received unknown op: %s', self._node.name, op) + + async def _handle_event(self, data: dict): + """ + Handles the event from Lavalink. + + Parameters + ---------- + data: :class:`dict` + The data given from Lavalink. + """ + player = self.client.player_manager.get(int(data['guildId'])) + event_type = data['type'] + + if not player: + if event_type not in ('TrackEndEvent', 'WebSocketClosedEvent'): # Player was most likely destroyed if it's any of these. + _log.warning('[Node:%s] Received event type %s for non-existent player! GuildId: %s', self._node.name, event_type, data['guildId']) + return + + event = None + + if event_type == 'TrackStartEvent': # Always fired after track end event (for previous track), and before any track exception/stuck events. + player.current = player._next + player._next = None + event = TrackStartEvent(player, player.current) + elif event_type == 'TrackEndEvent': + end_reason = EndReason.from_str(data['reason']) + event = TrackEndEvent(player, player.current, end_reason) + elif event_type == 'TrackExceptionEvent': + exception = data['exception'] + message = exception['message'] + severity = Severity.from_str(exception['severity']) + cause = exception['cause'] + event = TrackExceptionEvent(player, player.current, message, severity, cause) + elif event_type == 'TrackStuckEvent': + event = TrackStuckEvent(player, player.current, data['thresholdMs']) + elif event_type == 'WebSocketClosedEvent': + event = WebSocketClosedEvent(player, data['code'], data['reason'], data['byRemote']) + else: + _log.warning('[Node:%s] Unknown event received of type \'%s\'', self._node.name, event_type) + return + + await self.client._dispatch_event(event) + + if player: + try: + await player._handle_event(event) + except: # noqa: E722 pylint: disable=bare-except + _log.exception('Player %d encountered an error whilst handling event %s', player.guild_id, type(event).__name__) + + async def _send(self, **data): + """ + Sends a payload to Lavalink. + + Parameters + ---------- + data: :class:`dict` + The data sent to Lavalink. + """ + if not self.ws_connected: + _log.debug('[Node:%s] WebSocket not ready; queued outgoing payload.', self._node.name) + + if len(self._message_queue) >= MESSAGE_QUEUE_MAX_SIZE: + _log.warning('[Node:%s] WebSocket message queue is currently at capacity, discarding payload.', self._node.name) + else: + self._message_queue.append(data) + return + + _log.debug('[Node:%s] Sending payload %s', self._node.name, str(data)) + try: + await self._ws.send_json(data) + except ConnectionResetError: + _log.warning('[Node:%s] Failed to send payload due to connection reset!', self._node.name) + + async def _request(self, method: str, path: str, to=None, trace: bool = False, versioned: bool = True, **kwargs): # pylint: disable=C0103 + if self._destroyed: + raise IOError('Cannot instantiate any connections with a closed session!') + + if trace is True or self.trace_requests is True: + kwargs['params'] = {**kwargs.get('params', {}), 'trace': 'true'} + + if versioned: + request_url = f'{self.http_uri}/{LAVALINK_API_VERSION}/{path}' + else: + request_url = f'{self.http_uri}/{path}' + + _log.debug('[Node:%s] Sending request to Lavalink with the following parameters: method=%s, url=%s, params=%s, json=%s', + self._node.name, method, request_url, kwargs.get('params', {}), kwargs.get('json', {})) + + try: + async with self._session.request(method=method, url=request_url, + headers={'Authorization': self._password}, **kwargs) as res: + if res.status in (401, 403): + raise AuthenticationError + + if res.status == 200: + if to is str: + return await res.text() + + json = await res.json() + return json if to is None else to.from_dict(json) + + if res.status == 204: + return True + + raise RequestError('An invalid response was received from the node.', + status=res.status, response=await res.json(), params=kwargs.get('params', {})) + except aiohttp.ClientConnectorError as cce: + _log.error('Request "%s %s" failed', method, path) + raise ClientError from cce diff --git a/lavalink/utfm_codec.py b/lavalink/utfm_codec.py index 9474fd64..b8cd1ea2 100644 --- a/lavalink/utfm_codec.py +++ b/lavalink/utfm_codec.py @@ -28,29 +28,29 @@ def read_utfm(utf_len: int, utf_bytes: bytes) -> str: count = 0 while count < utf_len: - c = utf_bytes[count] & 0xff - if c > 127: + char = utf_bytes[count] & 0xff + if char > 127: break count += 1 - chars.append(chr(c)) + chars.append(chr(char)) while count < utf_len: - c = utf_bytes[count] & 0xff - shift = c >> 4 + char = utf_bytes[count] & 0xff + shift = char >> 4 if 0 <= shift <= 7: count += 1 - chars.append(chr(c)) + chars.append(chr(char)) elif 12 <= shift <= 13: count += 2 if count > utf_len: raise UnicodeDecodeError('malformed input: partial character at end') char2 = utf_bytes[count - 1] if (char2 & 0xC0) != 0x80: - raise UnicodeDecodeError('malformed input around byte ' + count) + raise UnicodeDecodeError(f'malformed input around byte {count}') - char_shift = ((c & 0x1F) << 6) | (char2 & 0x3F) + char_shift = ((char & 0x1F) << 6) | (char2 & 0x3F) chars.append(chr(char_shift)) elif shift == 14: count += 3 @@ -61,11 +61,11 @@ def read_utfm(utf_len: int, utf_bytes: bytes) -> str: char3 = utf_bytes[count - 1] if (char2 & 0xC0) != 0x80 or (char3 & 0xC0) != 0x80: - raise UnicodeDecodeError('malformed input around byte ' + (count - 1)) + raise UnicodeDecodeError(f'malformed input around byte {(count - 1)}') - char_shift = ((c & 0x0F) << 12) | ((char2 & 0x3F) << 6) | ((char3 & 0x3F) << 0) + char_shift = ((char & 0x0F) << 12) | ((char2 & 0x3F) << 6) | ((char3 & 0x3F) << 0) chars.append(chr(char_shift)) else: - raise UnicodeDecodeError('malformed input around byte ' + count) + raise UnicodeDecodeError(f'malformed input around byte {count}') return ''.join(chars).encode('utf-16', 'surrogatepass').decode('utf-16') diff --git a/lavalink/utils.py b/lavalink/utils.py index 6c5414d0..6a05cd05 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -23,10 +23,14 @@ """ import struct from base64 import b64encode -from typing import Tuple +from typing import Dict, Optional, Tuple, Union -from .datarw import DataReader, DataWriter -from .models import AudioTrack +from .dataio import DataReader, DataWriter +from .errors import InvalidTrack +from .player import AudioTrack + +V2_KEYSET = {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'sourceName', 'position'} +V3_KEYSET = {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'artworkUrl', 'isrc', 'sourceName', 'position'} def timestamp_to_millis(timestamp: str) -> int: @@ -51,8 +55,8 @@ def timestamp_to_millis(timestamp: str) -> int: """ try: sections = list(map(int, timestamp.split(':'))) - except ValueError as ve: - raise ValueError('Timestamp should consist of integers and colons only') from ve + except ValueError as error: + raise ValueError('Timestamp should consist of integers and colons only') from error if not sections: raise TypeError('An invalid timestamp was provided, a timestamp should look like 1:30') @@ -61,19 +65,19 @@ def timestamp_to_millis(timestamp: str) -> int: raise TypeError('Too many segments within the provided timestamp! Provide no more than 4 segments.') if len(sections) == 4: - d, h, m, s = map(int, sections) - return (d * 86400000) + (h * 3600000) + (m * 60000) + (s * 1000) + days, hours, minutes, seconds = map(int, sections) + return (days * 86400000) + (hours * 3600000) + (minutes * 60000) + (seconds * 1000) if len(sections) == 3: - h, m, s = map(int, sections) - return (h * 3600000) + (m * 60000) + (s * 1000) + hours, minutes, seconds = map(int, sections) + return (hours * 3600000) + (minutes * 60000) + (seconds * 1000) if len(sections) == 2: - m, s = map(int, sections) - return (m * 60000) + (s * 1000) + minutes, seconds = map(int, sections) + return (minutes * 60000) + (seconds * 1000) - s, = map(int, sections) - return s * 1000 + seconds, = map(int, sections) + return seconds * 1000 def format_time(time: int) -> str: @@ -92,7 +96,7 @@ def format_time(time: int) -> str: hours, remainder = divmod(time / 1000, 3600) minutes, seconds = divmod(remainder, 60) - return '%02d:%02d:%02d' % (hours, minutes, seconds) + return f'{hours:02.0f}:{minutes:02.0f}:{seconds:02.0f}' def parse_time(time: int) -> Tuple[int, int, int, int]: @@ -116,6 +120,33 @@ def parse_time(time: int) -> Tuple[int, int, int, int]: return days, hours, minutes, seconds +def _read_track_common(reader: DataReader) -> Tuple[str, str, int, str, bool, Optional[str]]: + """ + Reads common fields between v1-3 AudioTracks. + + Returns + ------- + Tuple[str, str, int, str, bool, Optional[str]] + A tuple containing (title, author, length, identifier, isStream, uri) fields. + """ + title = reader.read_utfm() + author = reader.read_utfm() + length = reader.read_long() + identifier = reader.read_utf().decode() + is_stream = reader.read_boolean() + uri = reader.read_nullable_utf() + return (title, author, length, identifier, is_stream, uri) + + +def _write_track_common(track: Dict[str, Union[Optional[str], bool, int]], writer: DataWriter): + writer.write_utf(track['title']) + writer.write_utf(track['author']) + writer.write_long(track['length']) + writer.write_utf(track['identifier']) + writer.write_boolean(track['isStream']) + writer.write_nullable_utf(track['uri']) + + def decode_track(track: str) -> AudioTrack: """ Decodes a base64 track string into an AudioTrack object. @@ -132,14 +163,15 @@ def decode_track(track: str) -> AudioTrack: reader = DataReader(track) flags = (reader.read_int() & 0xC0000000) >> 30 - version = struct.unpack('B', reader.read_byte()) if flags & 1 != 0 else 1 + version, = struct.unpack('B', reader.read_byte()) if flags & 1 != 0 else 1 + + title, author, length, identifier, is_stream, uri = _read_track_common(reader) + extra_fields = {} + + if version == 3: + extra_fields['artworkUrl'] = reader.read_nullable_utf() + extra_fields['isrc'] = reader.read_nullable_utf() - title = reader.read_utfm() - author = reader.read_utfm() - length = reader.read_long() - identifier = reader.read_utf().decode() - is_stream = reader.read_boolean() - uri = reader.read_utf().decode() if reader.read_boolean() else None source = reader.read_utf().decode() position = reader.read_long() @@ -153,30 +185,79 @@ def decode_track(track: str) -> AudioTrack: 'isStream': is_stream, 'uri': uri, 'isSeekable': not is_stream, - 'sourceName': source + 'sourceName': source, + **extra_fields } } return AudioTrack(track_object, 0, position=position, encoder_version=version) -def encode_track(track: dict) -> str: - assert {'title', 'author', 'length', 'identifier', 'isStream', 'uri', 'sourceName', 'position'} == track.keys() +def encode_track(track: Dict[str, Union[Optional[str], int, bool]]) -> Tuple[int, str]: + """ + Encodes a track dict into a base64 string, readable by the Lavalink server. + + A track should have *at least* the following keys: + ``title``, ``author``, ``length``, ``identifier``, ``isStream``, ``uri``, ``sourceName`` and ``position``. + + If the track is a v3 track, it should have the following additional fields: + ``artworkUrl`` and ``isrc``. isrc can be ``None`` if not applicable. + + Parameters + ---------- + track: Dict[str, Union[Optional[str], int, bool]] + The track dict to serialize. + + Raises + ------ + :class:`InvalidTrack` + If the track has unexpected, or missing keys, possibly due to an incompatible version or another reason. + + Returns + ------- + Tuple[int, str] + A tuple containing (track_version, encoded_track). + For example, if a track was encoded as version 3, the return value will be ``(3, '...really long track string...')``. + """ + track_keys = track.keys() # set(track) is faster for larger collections, but slower for smaller. + + if not V2_KEYSET <= track_keys: # V2_KEYSET contains the minimum number of fields required to successfully encode a track. + missing_keys = [k for k in V2_KEYSET if k not in track] + + raise InvalidTrack(f'Track object is missing keys required for serialization: {", ".join(missing_keys)}') + + if V3_KEYSET <= track_keys: + return (3, encode_track_v3(track)) + + return (2, encode_track_v2(track)) + + +def encode_track_v2(track: Dict[str, Union[Optional[str], bool, int]]) -> str: + assert V2_KEYSET <= track.keys() writer = DataWriter() version = struct.pack('B', 2) writer.write_byte(version) - writer.write_utf(track['title']) - writer.write_utf(track['author']) - writer.write_long(track['length']) - writer.write_utf(track['identifier']) - writer.write_boolean(track['isStream']) - writer.write_boolean(track['uri']) - writer.write_utf(track['uri']) + _write_track_common(track, writer) + writer.write_utf(track['sourceName']) + writer.write_long(track['position']) + + enc = writer.finish() + return b64encode(enc).decode() + + +def encode_track_v3(track: Dict[str, Union[Optional[str], bool, int]]) -> str: + assert V3_KEYSET <= track.keys() + + writer = DataWriter() + version = struct.pack('B', 3) + writer.write_byte(version) + _write_track_common(track, writer) + writer.write_nullable_utf(track['artworkUrl']) + writer.write_nullable_utf(track['isrc']) writer.write_utf(track['sourceName']) writer.write_long(track['position']) enc = writer.finish() - b64 = b64encode(enc) - return b64 + return b64encode(enc).decode() diff --git a/lavalink/websocket.py b/lavalink/websocket.py deleted file mode 100644 index 30c126ae..00000000 --- a/lavalink/websocket.py +++ /dev/null @@ -1,291 +0,0 @@ -""" -MIT License - -Copyright (c) 2017-present Devoxin - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -""" -import asyncio -import logging - -import aiohttp - -from .events import (PlayerUpdateEvent, TrackEndEvent, TrackExceptionEvent, - TrackStuckEvent, WebSocketClosedEvent) -from .stats import Stats -from .utils import decode_track - -_log = logging.getLogger(__name__) -CLOSE_TYPES = ( - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - aiohttp.WSMsgType.CLOSED -) - - -class WebSocket: - """ Represents the WebSocket connection with Lavalink. """ - def __init__(self, node, host: str, port: int, password: str, ssl: bool, resume_key: str, - resume_timeout: int, reconnect_attempts: int): - self._node = node - self._lavalink = self._node._manager._lavalink - - self._session = self._lavalink._session - self._ws = None - self._message_queue = [] - - self._host = host - self._port = port - self._password = password - self._ssl = ssl - self._max_reconnect_attempts = reconnect_attempts - - self._resume_key = resume_key - self._resume_timeout = resume_timeout - self._resuming_configured = False - self._destroyed = False - - self.connect() - - @property - def connected(self): - """ Returns whether the websocket is connected to Lavalink. """ - return self._ws is not None and not self._ws.closed - - async def close(self, code=aiohttp.WSCloseCode.OK): - """|coro| - - Shuts down the websocket connection if there is one. - """ - if self._ws: - await self._ws.close(code=code) - self._ws = None - - def connect(self): - """ Attempts to establish a connection to Lavalink. """ - return asyncio.ensure_future(self._connect()) - - async def destroy(self): - """|coro| - - Closes the WebSocket gracefully, and stops any further reconnecting. - Useful when needing to remove a node. - """ - self._destroyed = True - await self.close() - - async def _connect(self): - if self._ws: - await self.close() - - headers = { - 'Authorization': self._password, - 'User-Id': str(self._lavalink._user_id), - 'Client-Name': 'Lavalink.py', - 'Num-Shards': '1' # Legacy header that is no longer used. Here for compatibility. - } - # TODO: 'User-Agent': 'Lavalink.py/{} (https://github.com/devoxin/lavalink.py)'.format(__version__) - - if self._resuming_configured and self._resume_key: - headers['Resume-Key'] = self._resume_key - - is_finite_retry = self._max_reconnect_attempts != -1 - max_attempts_str = self._max_reconnect_attempts if is_finite_retry else 'inf' - attempt = 0 - - while not self.connected and (not is_finite_retry or attempt < self._max_reconnect_attempts): - attempt += 1 - _log.info('[Node:%s] Attempting to establish WebSocket connection (%d/%s)...', self._node.name, attempt, max_attempts_str) - - protocol = 'wss' if self._ssl else 'ws' - try: - self._ws = await self._session.ws_connect('{}://{}:{}'.format(protocol, self._host, self._port), - headers=headers, heartbeat=60) - except (aiohttp.ClientConnectorError, aiohttp.WSServerHandshakeError, aiohttp.ServerDisconnectedError, asyncio.exceptions.TimeoutError) as ce: - if isinstance(ce, aiohttp.ClientConnectorError): - _log.warning('[Node:%s] Invalid response received; this may indicate that ' - 'Lavalink is not running, or is running on a port different ' - 'to the one you provided to `add_node`.', self._node.name) - elif isinstance(ce, aiohttp.WSServerHandshakeError): - if ce.status in (401, 403): # Special handling for 401/403 (Unauthorized/Forbidden). - _log.warning('[Node:%s] Authentication failed while trying to establish a connection to the node.', - self._node.name) - # We shouldn't try to establish any more connections as correcting this particular error - # would require the cog to be reloaded (or the bot to be rebooted), so further attempts - # would be futile, and a waste of resources. - return - - _log.warning('[Node:%s] The remote server returned code %d, the expected code was 101. This usually ' - 'indicates that the remote server is a webserver and not Lavalink. Check your ports, ' - 'and try again.', self._node.name, ce.status) - elif isinstance(ce, asyncio.exceptions.TimeoutError): - _log.warning('[Node:%s] The remote server is not responding.', self._node.name) - else: - _log.exception('[Node:%s] An unknown error occurred whilst trying to establish ' - 'a connection to Lavalink', self._node.name) - backoff = min(10 * attempt, 60) - await asyncio.sleep(backoff) - else: - _log.info('[Node:%s] WebSocket connection established', self._node.name) - await self._node._manager._node_connect(self._node) - - if not self._resuming_configured and self._resume_key \ - and (self._resume_timeout and self._resume_timeout > 0): - await self._send(op='configureResuming', key=self._resume_key, timeout=self._resume_timeout) - self._resuming_configured = True - - if self._message_queue: - for message in self._message_queue: - await self._send(**message) - - self._message_queue.clear() - - await self._listen() - # Ensure this loop doesn't proceed if _listen returns control back to this function. - return - - _log.warning('[Node:%s] A WebSocket connection could not be established within %s attempts.', self._node.name, max_attempts_str) - - async def _listen(self): - """ Listens for websocket messages. """ - async for msg in self._ws: - _log.debug('[Node:%s] Received WebSocket message: %s', self._node.name, msg.data) - - if msg.type == aiohttp.WSMsgType.TEXT: - await self._handle_message(msg.json()) - elif msg.type == aiohttp.WSMsgType.ERROR: - exc = self._ws.exception() - _log.error('[Node:%s] Exception in WebSocket!', self._node.name, exc_info=exc) - break - elif msg.type in CLOSE_TYPES: - _log.debug('[Node:%s] Received close frame with code %d.', self._node.name, msg.data) - await self._websocket_closed(msg.data, msg.extra) - return - await self._websocket_closed(self._ws.close_code, 'AsyncIterator loop exited') - - async def _websocket_closed(self, code: int = None, reason: str = None): - """ - Handles when the websocket is closed. - - Parameters - ---------- - code: Optional[:class:`int`] - The response code. - reason: Optional[:class:`str`] - Reason why the websocket was closed. Defaults to ``None`` - """ - _log.warning('[Node:%s] WebSocket disconnected with the following: code=%d reason=%s', self._node.name, code, reason) - self._ws = None - await self._node._manager._node_disconnect(self._node, code, reason) - - if not self._destroyed: - await self._connect() - - async def _handle_message(self, data: dict): - """ - Handles the response from the websocket. - - Parameters - ---------- - data: :class:`dict` - The data given from Lavalink. - """ - op = data['op'] - - if op == 'stats': - self._node.stats = Stats(self._node, data) - elif op == 'playerUpdate': - player_id = data['guildId'] - player = self._lavalink.player_manager.get(int(player_id)) - - if not player: - _log.debug('[Node:%s] Received playerUpdate for non-existent player! GuildId: %s', self._node.name, player_id) - return - - await player._update_state(data['state']) - await self._lavalink._dispatch_event(PlayerUpdateEvent(player, data['state'])) - elif op == 'event': - await self._handle_event(data) - else: - _log.warning('[Node:%s] Received unknown op: %s', self._node.name, op) - - async def _handle_event(self, data: dict): - """ - Handles the event from Lavalink. - - Parameters - ---------- - data: :class:`dict` - The data given from Lavalink. - """ - player = self._lavalink.player_manager.get(int(data['guildId'])) - event_type = data['type'] - - if not player: - if event_type not in ('TrackEndEvent', 'WebSocketClosedEvent'): # Player was most likely destroyed if it's any of these. - _log.warning('[Node:%s] Received event type %s for non-existent player! GuildId: %s', self._node.name, event_type, data['guildId']) - return - - event = None - - if event_type == 'TrackEndEvent': - track = decode_track(data['track']) if data['track'] else None - event = TrackEndEvent(player, track, data['reason']) - elif event_type == 'TrackExceptionEvent': - exc_inner = data.get('exception', {}) - exception = data.get('error') or exc_inner.get('cause', 'Unknown exception') - severity = exc_inner.get('severity', 'UNKNOWN') - event = TrackExceptionEvent(player, player.current, exception, severity) - # elif event_type == 'TrackStartEvent': - # event = TrackStartEvent(player, player.current) - elif event_type == 'TrackStuckEvent': - event = TrackStuckEvent(player, player.current, data['thresholdMs']) - elif event_type == 'WebSocketClosedEvent': - event = WebSocketClosedEvent(player, data['code'], data['reason'], data['byRemote']) - else: - if event_type == 'TrackStartEvent': - return - - _log.warning('[Node:%s] Unknown event received of type \'%s\'', self._node.name, event_type) - return - - await self._lavalink._dispatch_event(event) - - if player: - await player._handle_event(event) - - async def _send(self, **data): - """ - Sends a payload to Lavalink. - - Parameters - ---------- - data: :class:`dict` - The data sent to Lavalink. - """ - if not self.connected: - _log.debug('[Node:%s] WebSocket not ready; queued outgoing payload', self._node.name) - self._message_queue.append(data) - return - - _log.debug('[Node:%s] Sending payload %s', self._node.name, str(data)) - try: - await self._ws.send_json(data) - except ConnectionResetError: - _log.warning('[Node:%s] Failed to send payload due to connection reset!', self._node.name) diff --git a/run_tests.py b/run_tests.py index 7e1e3d8a..6772f871 100644 --- a/run_tests.py +++ b/run_tests.py @@ -23,10 +23,9 @@ def test_pylint(): reporter = text.TextReporter(stdout) opts = ['--max-line-length=150', '--score=no', '--disable=missing-docstring,wildcard-import,' 'attribute-defined-outside-init,too-few-public-methods,' - 'old-style-class,import-error,invalid-name,no-init,' - 'too-many-instance-attributes,protected-access,too-many-arguments,' - 'too-many-public-methods,logging-format-interpolation,' - 'too-many-branches', 'lavalink'] + 'too-many-instance-attributes,protected-access,' + 'too-many-arguments,too-many-public-methods,too-many-branches,' + 'consider-using-with', 'lavalink'] pylint.Run(opts, reporter=reporter, do_exit=False) out = reporter.out.getvalue() diff --git a/setup.py b/setup.py index 25076712..5e459c8a 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ download_url='https://github.com/Devoxin/Lavalink.py/archive/{}.tar.gz'.format(version), keywords=['lavalink'], include_package_data=True, - install_requires=['aiohttp>=3.7.4,<3.9.0'], + install_requires=['aiohttp>=3.9.0,<4'], extras_require={'docs': ['sphinx', 'pygments', 'guzzle_sphinx_theme',