From ad4e85f9957dd4ffde479551e8997bee4aab642f Mon Sep 17 00:00:00 2001 From: vi <8530778+shiftinv@users.noreply.github.com> Date: Wed, 25 Dec 2024 17:39:41 +0100 Subject: [PATCH 1/4] feat(channel): add `GroupChannel.get_partial_message` (#1256) --- changelog/1256.feature.rst | 1 + disnake/channel.py | 22 ++++++++++++++++++++++ disnake/message.py | 4 +++- 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 changelog/1256.feature.rst diff --git a/changelog/1256.feature.rst b/changelog/1256.feature.rst new file mode 100644 index 0000000000..56996d6e38 --- /dev/null +++ b/changelog/1256.feature.rst @@ -0,0 +1 @@ +Add :meth:`GroupChannel.get_partial_message`. diff --git a/disnake/channel.py b/disnake/channel.py index 90d5792a39..55b35562cd 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -4968,6 +4968,28 @@ def permissions_for( return base + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the given message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 2.10 + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message object. + """ + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + async def leave(self) -> None: """|coro| diff --git a/disnake/message.py b/disnake/message.py index efe965ae13..e4cb7ee0e6 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -2511,6 +2511,7 @@ class PartialMessage(Hashable): - :meth:`StageChannel.get_partial_message` - :meth:`Thread.get_partial_message` - :meth:`DMChannel.get_partial_message` + - :meth:`GroupChannel.get_partial_message` - :meth:`PartialMessageable.get_partial_message` Note that this class is trimmed down and has no rich attributes. @@ -2560,6 +2561,7 @@ def __init__(self, *, channel: MessageableChannel, id: int) -> None: ChannelType.text, ChannelType.news, ChannelType.private, + ChannelType.group, ChannelType.news_thread, ChannelType.public_thread, ChannelType.private_thread, @@ -2567,7 +2569,7 @@ def __init__(self, *, channel: MessageableChannel, id: int) -> None: ChannelType.stage_voice, ): raise TypeError( - f"Expected TextChannel, VoiceChannel, DMChannel, StageChannel, Thread, or PartialMessageable " + f"Expected TextChannel, VoiceChannel, StageChannel, Thread, DMChannel, GroupChannel, or PartialMessageable " f"with a valid type, not {type(channel)!r} (type: {channel.type!r})" ) From 2867a91ca15f06f6f5387cd74cdb6b72472865d5 Mon Sep 17 00:00:00 2001 From: vi <8530778+shiftinv@users.noreply.github.com> Date: Wed, 25 Dec 2024 17:49:02 +0100 Subject: [PATCH 2/4] feat: soundboard (#1068) --- changelog/1068.feature.0.rst | 9 + changelog/1068.feature.1.rst | 1 + disnake/__init__.py | 1 + disnake/asset.py | 2 +- disnake/audit_logs.py | 20 ++ disnake/channel.py | 57 ++++- disnake/client.py | 47 +++- disnake/enums.py | 30 ++- disnake/ext/commands/converter.py | 40 ++++ disnake/ext/commands/errors.py | 19 ++ disnake/flags.py | 42 +++- disnake/guild.py | 144 +++++++++++- disnake/http.py | 121 ++++++++++- disnake/soundboard.py | 313 +++++++++++++++++++++++++++ disnake/state.py | 180 +++++++++++++-- disnake/types/audit_log.py | 3 + disnake/types/gateway.py | 21 ++ disnake/types/guild.py | 4 + disnake/types/soundboard.py | 31 +++ disnake/types/voice.py | 2 + disnake/utils.py | 18 +- docs/api/audit_logs.rst | 54 ++++- docs/api/events.rst | 22 +- docs/api/index.rst | 1 + docs/api/soundboard.rst | 45 ++++ docs/ext/commands/api/converters.rst | 3 + docs/ext/commands/api/exceptions.rst | 4 + docs/ext/commands/commands.rst | 3 + tests/test_utils.py | 19 +- 29 files changed, 1187 insertions(+), 69 deletions(-) create mode 100644 changelog/1068.feature.0.rst create mode 100644 changelog/1068.feature.1.rst create mode 100644 disnake/soundboard.py create mode 100644 disnake/types/soundboard.py create mode 100644 docs/api/soundboard.rst diff --git a/changelog/1068.feature.0.rst b/changelog/1068.feature.0.rst new file mode 100644 index 0000000000..7bcf5e83c1 --- /dev/null +++ b/changelog/1068.feature.0.rst @@ -0,0 +1,9 @@ +Implement soundboard features. +- Sound models: :class:`PartialSoundboardSound`, :class:`SoundboardSound`, :class:`GuildSoundboardSound` +- Managing sounds: + - Get soundboard sounds using :attr:`Guild.soundboard_sounds`, :attr:`Client.get_soundboard_sound`, or fetch them using :meth:`Guild.fetch_soundboard_sound`, :meth:`Guild.fetch_soundboard_sounds`, or :meth:`Client.fetch_default_soundboard_sounds` + - New sounds can be created with :meth:`Guild.create_soundboard_sound` + - Handle guild soundboard sound updates using the :attr:`~Event.guild_soundboard_sounds_update` event +- Send sounds using :meth:`VoiceChannel.send_soundboard_sound` +- New attributes: :attr:`Guild.soundboard_limit`, :attr:`VoiceChannelEffect.sound`, :attr:`Client.soundboard_sounds` +- New audit log actions: :attr:`AuditLogAction.soundboard_sound_create`, :attr:`~AuditLogAction.soundboard_sound_update`, :attr:`~AuditLogAction.soundboard_sound_delete` diff --git a/changelog/1068.feature.1.rst b/changelog/1068.feature.1.rst new file mode 100644 index 0000000000..ff1ed09c5e --- /dev/null +++ b/changelog/1068.feature.1.rst @@ -0,0 +1 @@ +Rename :attr:`Intents.emojis_and_stickers` to :attr:`Intents.expressions`. An alias is provided for backwards compatibility. diff --git a/disnake/__init__.py b/disnake/__init__.py index e0af7f3354..705de41175 100644 --- a/disnake/__init__.py +++ b/disnake/__init__.py @@ -63,6 +63,7 @@ from .role import * from .shard import * from .sku import * +from .soundboard import * from .stage_instance import * from .sticker import * from .team import * diff --git a/disnake/asset.py b/disnake/asset.py index cf3cf5ac66..30fdd31aab 100644 --- a/disnake/asset.py +++ b/disnake/asset.py @@ -154,7 +154,7 @@ async def to_file( # if the filename doesn't have an extension (e.g. widget member avatars), # try to infer it from the data if not os.path.splitext(filename)[1]: - ext = utils._get_extension_for_image(data) + ext = utils._get_extension_for_data(data) if ext: filename += ext diff --git a/disnake/audit_logs.py b/disnake/audit_logs.py index cc2948f9a3..1a71095822 100644 --- a/disnake/audit_logs.py +++ b/disnake/audit_logs.py @@ -365,6 +365,7 @@ class AuditLogChanges: "available_tags": (None, _list_transformer(_transform_tag)), "default_reaction_emoji": ("default_reaction", _transform_default_reaction), "default_sort_order": (None, _enum_transformer(enums.ThreadSortOrder)), + "sound_id": ("id", _transform_snowflake), } # fmt: on @@ -372,6 +373,8 @@ def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]) -> N self.before = AuditLogDiff() self.after = AuditLogDiff() + has_emoji_fields = False + for elem in data: attr = elem["key"] @@ -390,6 +393,10 @@ def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]) -> N ) continue + # special case for flat emoji fields (discord, why), these will be merged later + if attr == "emoji_id" or attr == "emoji_name": + has_emoji_fields = True + transformer: Optional[Transformer] try: @@ -420,6 +427,9 @@ def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]) -> N setattr(self.after, attr, after) + if has_emoji_fields: + self._merge_emoji(entry) + # add an alias if hasattr(self.after, "colour"): self.after.color = self.after.colour @@ -478,6 +488,16 @@ def _handle_command_permissions( data=new, guild_id=guild_id ) + def _merge_emoji(self, entry: AuditLogEntry) -> None: + for diff in (self.before, self.after): + emoji_id: Optional[str] = diff.__dict__.pop("emoji_id", None) + emoji_name: Optional[str] = diff.__dict__.pop("emoji_name", None) + + diff.emoji = entry._state._get_emoji_from_fields( + name=emoji_name, + id=int(emoji_id) if emoji_id else None, + ) + class _AuditLogProxyMemberPrune: delete_member_days: int diff --git a/disnake/channel.py b/disnake/channel.py index 55b35562cd..d7c72ca5a1 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -46,6 +46,7 @@ from .mixins import Hashable from .partial_emoji import PartialEmoji from .permissions import PermissionOverwrite, Permissions +from .soundboard import GuildSoundboardSound, PartialSoundboardSound, SoundboardSound from .stage_instance import StageInstance from .threads import ForumTag, Thread from .utils import MISSING @@ -91,6 +92,7 @@ VoiceChannel as VoiceChannelPayload, ) from .types.snowflake import SnowflakeList + from .types.soundboard import PartialSoundboardSound as PartialSoundboardSoundPayload from .types.threads import ThreadArchiveDurationLiteral from .types.voice import VoiceChannelEffect as VoiceChannelEffectPayload from .ui.action_row import Components, MessageUIComponent @@ -110,17 +112,22 @@ class VoiceChannelEffect: Attributes ---------- emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]] - The emoji, for emoji reaction effects. + The emoji, for emoji reaction effects and soundboard effects. animation_type: Optional[:class:`VoiceChannelEffectAnimationType`] - The emoji animation type, for emoji reaction effects. + The emoji animation type, for emoji reaction and soundboard effects. animation_id: Optional[:class:`int`] - The emoji animation ID, for emoji reaction effects. + The emoji animation ID, for emoji reaction and soundboard effects. + sound: Optional[Union[:class:`GuildSoundboardSound`, :class:`PartialSoundboardSound`]] + The sound data, for soundboard effects. + This will be a :class:`PartialSoundboardSound` if it's a default sound + or from an external guild. """ __slots__ = ( "emoji", "animation_type", "animation_id", + "sound", ) def __init__(self, *, data: VoiceChannelEffectPayload, state: ConnectionState) -> None: @@ -138,10 +145,21 @@ def __init__(self, *, data: VoiceChannelEffectPayload, state: ConnectionState) - ) self.animation_id: Optional[int] = utils._get_as_snowflake(data, "animation_id") + self.sound: Optional[Union[GuildSoundboardSound, PartialSoundboardSound]] = None + if sound_id := utils._get_as_snowflake(data, "sound_id"): + if sound := state.get_soundboard_sound(sound_id): + self.sound = sound + else: + sound_data: PartialSoundboardSoundPayload = { + "sound_id": sound_id, + "volume": data.get("sound_volume"), # type: ignore # assume this exists if sound_id is set + } + self.sound = PartialSoundboardSound(data=sound_data, state=state) + def __repr__(self) -> str: return ( f"" + f" animation_id={self.animation_id!r} sound={self.sound!r}>" ) @@ -1916,6 +1934,37 @@ async def create_webhook( ) return Webhook.from_state(data, state=self._state) + async def send_soundboard_sound(self, sound: SoundboardSound, /) -> None: + """|coro| + + Sends a soundboard sound in this channel. + + You must have :attr:`~Permissions.speak` and :attr:`~Permissions.use_soundboard` + permissions to do this. For sounds from different guilds, you must also have + :attr:`~Permissions.use_external_sounds` permission. + Additionally, you may not be muted or deafened. + + Parameters + ---------- + sound: Union[:class:`SoundboardSound`, :class:`GuildSoundboardSound`] + The sound to send in the channel. + + Raises + ------ + Forbidden + You are not allowed to send soundboard sounds. + HTTPException + An error occurred sending the soundboard sound. + """ + if isinstance(sound, GuildSoundboardSound): + source_guild_id = sound.guild_id + else: + source_guild_id = None + + await self._state.http.send_soundboard_sound( + self.id, sound.id, source_guild_id=source_guild_id + ) + class StageChannel(disnake.abc.Messageable, VocalGuildChannel): """Represents a Discord guild stage channel. diff --git a/disnake/client.py b/disnake/client.py index 990e19c201..23cc8ae626 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -68,6 +68,7 @@ from .mentions import AllowedMentions from .object import Object from .sku import SKU +from .soundboard import GuildSoundboardSound, SoundboardSound from .stage_instance import StageInstance from .state import ConnectionState from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory @@ -571,6 +572,14 @@ def stickers(self) -> List[GuildSticker]: """ return self._connection.stickers + @property + def soundboard_sounds(self) -> List[GuildSoundboardSound]: + """List[:class:`.GuildSoundboardSound`]: The soundboard sounds that the connected client has. + + .. versionadded:: 2.10 + """ + return self._connection.soundboard_sounds + @property def cached_messages(self) -> Sequence[Message]: """Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached. @@ -1501,7 +1510,7 @@ def get_sticker(self, id: int, /) -> Optional[GuildSticker]: .. note:: - To retrieve standard stickers, use :meth:`.fetch_sticker`. + To retrieve standard stickers, use :meth:`.fetch_sticker` or :meth:`.fetch_sticker_packs`. Returns @@ -1511,6 +1520,22 @@ def get_sticker(self, id: int, /) -> Optional[GuildSticker]: """ return self._connection.get_sticker(id) + def get_soundboard_sound(self, id: int, /) -> Optional[GuildSoundboardSound]: + """Returns a guild soundboard sound with the given ID. + + .. versionadded:: 2.10 + + .. note:: + + To retrieve standard soundboard sounds, use :meth:`.fetch_default_soundboard_sounds`. + + Returns + ------- + Optional[:class:`.GuildSoundboardSound`] + The soundboard sound or ``None`` if not found. + """ + return self._connection.get_soundboard_sound(id) + def get_all_channels(self) -> Generator[GuildChannel, None, None]: """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. @@ -2357,6 +2382,26 @@ async def fetch_widget(self, guild_id: int, /) -> Widget: data = await self.http.get_widget(guild_id) return Widget(state=self._connection, data=data) + async def fetch_default_soundboard_sounds(self) -> List[SoundboardSound]: + """|coro| + + Retrieves the list of default :class:`.SoundboardSound`\\s provided by Discord. + + .. versionadded:: 2.10 + + Raises + ------ + HTTPException + Retrieving the soundboard sounds failed. + + Returns + ------- + List[:class:`.SoundboardSound`] + The default soundboard sounds. + """ + data = await self.http.get_default_soundboard_sounds() + return [SoundboardSound(data=d, state=self._connection) for d in data] + async def application_info(self) -> AppInfo: """|coro| diff --git a/disnake/enums.py b/disnake/enums.py index ddbc2c42bb..1707d0976b 100644 --- a/disnake/enums.py +++ b/disnake/enums.py @@ -403,6 +403,9 @@ class AuditLogAction(Enum): thread_update = 111 thread_delete = 112 application_command_permission_update = 121 + soundboard_sound_create = 130 + soundboard_sound_update = 131 + soundboard_sound_delete = 132 automod_rule_create = 140 automod_rule_update = 141 automod_rule_delete = 142 @@ -465,6 +468,9 @@ def category(self) -> Optional[AuditLogActionCategory]: AuditLogAction.guild_scheduled_event_update: AuditLogActionCategory.update, AuditLogAction.guild_scheduled_event_delete: AuditLogActionCategory.delete, AuditLogAction.application_command_permission_update: AuditLogActionCategory.update, + AuditLogAction.soundboard_sound_create: AuditLogActionCategory.create, + AuditLogAction.soundboard_sound_update: AuditLogActionCategory.update, + AuditLogAction.soundboard_sound_delete: AuditLogActionCategory.delete, AuditLogAction.automod_rule_create: AuditLogActionCategory.create, AuditLogAction.automod_rule_update: AuditLogActionCategory.update, AuditLogAction.automod_rule_delete: AuditLogActionCategory.delete, @@ -1065,6 +1071,12 @@ class Event(Enum): """Called when a `Guild` updates its stickers. Represents the :func:`on_guild_stickers_update` event. """ + guild_soundboard_sounds_update = "guild_soundboard_sounds_update" + """Called when a `Guild` updates its soundboard sounds. + Represents the :func:`on_guild_soundboard_sounds_update` event. + + .. versionadded:: 2.10 + """ guild_integrations_update = "guild_integrations_update" """Called whenever an integration is created, modified, or removed from a guild. Represents the :func:`on_guild_integrations_update` event. @@ -1288,7 +1300,8 @@ class Event(Enum): """ raw_presence_update = "raw_presence_update" """Called when a user's presence changes regardless of the state of the internal member cache. - Represents the :func:`on_raw_presence_update` event.""" + Represents the :func:`on_raw_presence_update` event. + """ raw_reaction_add = "raw_reaction_add" """Called when a message has a reaction added regardless of the state of the internal message cache. Represents the :func:`on_raw_reaction_add` event. @@ -1315,13 +1328,22 @@ class Event(Enum): """ entitlement_create = "entitlement_create" """Called when a user subscribes to an SKU, creating a new :class:`Entitlement`. - Represents the :func:`on_entitlement_create` event.""" + Represents the :func:`on_entitlement_create` event. + + .. versionadded:: 2.10 + """ entitlement_update = "entitlement_update" """Called when a user's subscription renews. - Represents the :func:`on_entitlement_update` event.""" + Represents the :func:`on_entitlement_update` event. + + .. versionadded:: 2.10 + """ entitlement_delete = "entitlement_delete" """Called when a user's entitlement is deleted. - Represents the :func:`on_entitlement_delete` event.""" + Represents the :func:`on_entitlement_delete` event. + + .. versionadded:: 2.10 + """ # ext.commands events command = "command" """Called when a command is found and is about to be invoked. diff --git a/disnake/ext/commands/converter.py b/disnake/ext/commands/converter.py index 79570b9f8e..4288a8b85a 100644 --- a/disnake/ext/commands/converter.py +++ b/disnake/ext/commands/converter.py @@ -40,6 +40,7 @@ EmojiNotFound, GuildNotFound, GuildScheduledEventNotFound, + GuildSoundboardSoundNotFound, GuildStickerNotFound, MemberNotFound, MessageNotFound, @@ -82,6 +83,7 @@ "EmojiConverter", "PartialEmojiConverter", "GuildStickerConverter", + "GuildSoundboardSoundConverter", "PermissionsConverter", "GuildScheduledEventConverter", "clean_content", @@ -944,6 +946,43 @@ async def convert(self, ctx: AnyContext, argument: str) -> disnake.GuildSticker: return result +class GuildSoundboardSoundConverter(IDConverter[disnake.GuildSoundboardSound]): + """Converts to a :class:`~disnake.GuildSoundboardSound`. + + All lookups are done for the local guild first, if available. If that lookup + fails, then it checks the client's global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID + 2. Lookup by name + + .. versionadded:: 2.10 + """ + + async def convert(self, ctx: AnyContext, argument: str) -> disnake.GuildSoundboardSound: + match = self._get_id_match(argument) + result = None + bot: disnake.Client = ctx.bot + guild = ctx.guild + + if match is None: + # Try to get the sound by name. Try local guild first. + if guild: + result = _utils_get(guild.soundboard_sounds, name=argument) + + if result is None: + result = _utils_get(bot.soundboard_sounds, name=argument) + else: + # Try to look up sound by id. + result = bot.get_soundboard_sound(int(match.group(1))) + + if result is None: + raise GuildSoundboardSoundNotFound(argument) + + return result + + class PermissionsConverter(Converter[disnake.Permissions]): """Converts to a :class:`~disnake.Permissions`. @@ -1212,6 +1251,7 @@ def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: disnake.Thread: ThreadConverter, disnake.abc.GuildChannel: GuildChannelConverter, disnake.GuildSticker: GuildStickerConverter, + disnake.GuildSoundboardSound: GuildSoundboardSoundConverter, disnake.Permissions: PermissionsConverter, disnake.GuildScheduledEvent: GuildScheduledEventConverter, } diff --git a/disnake/ext/commands/errors.py b/disnake/ext/commands/errors.py index cfa4c12f03..dde960db2b 100644 --- a/disnake/ext/commands/errors.py +++ b/disnake/ext/commands/errors.py @@ -49,6 +49,7 @@ "BadInviteArgument", "EmojiNotFound", "GuildStickerNotFound", + "GuildSoundboardSoundNotFound", "GuildScheduledEventNotFound", "PartialEmojiConversionFailure", "BadBoolArgument", @@ -499,6 +500,24 @@ def __init__(self, argument: str) -> None: super().__init__(f'Sticker "{argument}" not found.') +class GuildSoundboardSoundNotFound(BadArgument): + """Exception raised when the bot can not find the soundboard sound. + + This inherits from :exc:`BadArgument` + + .. versionadded:: 2.10 + + Attributes + ---------- + argument: :class:`str` + The soundboard sound supplied by the caller that was not found + """ + + def __init__(self, argument: str) -> None: + self.argument: str = argument + super().__init__(f'Soundboard sound "{argument}" not found.') + + class GuildScheduledEventNotFound(BadArgument): """Exception raised when the bot cannot find the scheduled event. diff --git a/disnake/flags.py b/disnake/flags.py index 5117f7b80b..53a72c5cbf 100644 --- a/disnake/flags.py +++ b/disnake/flags.py @@ -1060,6 +1060,7 @@ def __init__( dm_typing: bool = ..., emojis: bool = ..., emojis_and_stickers: bool = ..., + expressions: bool = ..., guild_messages: bool = ..., guild_polls: bool = ..., guild_reactions: bool = ..., @@ -1215,35 +1216,52 @@ def bans(self): return 1 << 2 @flag_value - def emojis(self): - """:class:`bool`: Alias of :attr:`.emojis_and_stickers`. - - .. versionchanged:: 2.0 - Changed to an alias. - """ - return 1 << 3 - - @alias_flag_value - def emojis_and_stickers(self): - """:class:`bool`: Whether guild emoji and sticker related events are enabled. + def expressions(self): + """:class:`bool`: Whether events related to guild emojis, stickers, and + soundboard sounds are enabled. - .. versionadded:: 2.0 + .. versionadded:: 2.10 This corresponds to the following events: - :func:`on_guild_emojis_update` - :func:`on_guild_stickers_update` + - :func:`on_guild_soundboard_sounds_update` This also corresponds to the following attributes and classes in terms of cache: - :class:`Emoji` - :class:`GuildSticker` + - :class:`GuildSoundboardSound` - :meth:`Client.get_emoji` - :meth:`Client.get_sticker` + - :meth:`Client.get_soundboard_sound` - :meth:`Client.emojis` - :meth:`Client.stickers` + - :meth:`Client.soundboard_sounds` - :attr:`Guild.emojis` - :attr:`Guild.stickers` + - :attr:`Guild.soundboard_sounds` + """ + return 1 << 3 + + @alias_flag_value + def emojis_and_stickers(self): + """:class:`bool`: Alias of :attr:`.expressions`. + + .. versionadded:: 2.0 + + .. versionchanged:: 2.10 + Changed to an alias. + """ + return 1 << 3 + + @alias_flag_value + def emojis(self): + """:class:`bool`: Alias of :attr:`.expressions`. + + .. versionchanged:: 2.0 + Changed to an alias. """ return 1 << 3 diff --git a/disnake/guild.py b/disnake/guild.py index a0609063bf..1b436e4625 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -74,6 +74,7 @@ from .partial_emoji import PartialEmoji from .permissions import PermissionOverwrite from .role import Role +from .soundboard import GuildSoundboardSound from .stage_instance import StageInstance from .sticker import GuildSticker from .threads import Thread, ThreadMember @@ -128,6 +129,7 @@ class _GuildLimit(NamedTuple): stickers: int bitrate: float filesize: int + sounds: int class Guild(Hashable): @@ -164,6 +166,11 @@ class Guild(Hashable): .. versionadded:: 2.0 + soundboard_sounds: Tuple[:class:`GuildSoundboardSound`, ...] + All soundboard sounds that the guild owns. + + .. versionadded:: 2.10 + afk_timeout: :class:`int` The timeout to get sent to the AFK channel. afk_channel: Optional[:class:`VoiceChannel`] @@ -232,6 +239,7 @@ class Guild(Hashable): - ``LINKED_TO_HUB``: Guild is linked to a student hub. - ``MEMBER_VERIFICATION_GATE_ENABLED``: Guild has Membership Screening enabled. - ``MORE_EMOJI``: Guild has increased custom emoji slots. + - ``MORE_SOUNDBOARD``: Guild has increased custom soundboard slots. - ``MORE_STICKERS``: Guild has increased custom sticker slots. - ``NEWS``: Guild can create news channels. - ``NEW_THREAD_PERMISSIONS``: Guild is using the new thread permission system. @@ -243,6 +251,7 @@ class Guild(Hashable): - ``ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE``: Guild has role subscriptions that can be purchased. - ``ROLE_SUBSCRIPTIONS_ENABLED``: Guild has enabled role subscriptions. - ``SEVEN_DAY_THREAD_ARCHIVE``: Guild has access to the seven day archive time for threads (no longer has any effect). + - ``SOUNDBOARD``: Guild has created soundboard sounds. - ``TEXT_IN_VOICE_ENABLED``: Guild has text in voice channels enabled (no longer has any effect). - ``THREE_DAY_THREAD_ARCHIVE``: Guild has access to the three day archive time for threads (no longer has any effect). - ``THREADS_ENABLED``: Guild has access to threads (no longer has any effect). @@ -321,6 +330,7 @@ class Guild(Hashable): "mfa_level", "emojis", "stickers", + "soundboard_sounds", "features", "verification_level", "explicit_content_filter", @@ -363,11 +373,11 @@ class Guild(Hashable): ) _PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = { - None: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400), - 0: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400), - 1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=26214400), - 2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52428800), - 3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600), + None: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400, sounds=8), + 0: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400, sounds=8), + 1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=26214400, sounds=24), + 2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52428800, sounds=36), + 3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600, sounds=48), } def __init__(self, *, data: GuildPayload, state: ConnectionState) -> None: @@ -554,6 +564,9 @@ def _from_data(self, guild: GuildPayload) -> None: self.stickers: Tuple[GuildSticker, ...] = tuple( state.store_sticker(self, d) for d in guild.get("stickers", []) ) + self.soundboard_sounds: Tuple[GuildSoundboardSound, ...] = tuple( + state.store_soundboard_sound(self, d) for d in guild.get("soundboard_sounds", []) + ) self.features: List[GuildFeature] = guild.get("features", []) self._splash: Optional[str] = guild.get("splash") self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, "system_channel_id") @@ -926,6 +939,15 @@ def filesize_limit(self) -> int: """:class:`int`: The maximum number of bytes files can have when uploaded to this guild.""" return self._PREMIUM_GUILD_LIMITS[self.premium_tier].filesize + @property + def soundboard_limit(self) -> int: + """:class:`int`: The maximum number of soundboard slots this guild has. + + .. versionadded:: 2.10 + """ + more_soundboard = 96 if "MORE_SOUNDBOARD" in self.features else 0 + return max(more_soundboard, self._PREMIUM_GUILD_LIMITS[self.premium_tier].sounds) + @property def members(self) -> List[Member]: """List[:class:`Member`]: A list of members that belong to this guild.""" @@ -5033,6 +5055,118 @@ async def onboarding(self) -> Onboarding: data = await self._state.http.get_guild_onboarding(self.id) return Onboarding(data=data, guild=self) + async def create_soundboard_sound( + self, + *, + name: str, + sound: AssetBytes, + volume: Optional[float] = None, + emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, + reason: Optional[str] = None, + ) -> GuildSoundboardSound: + """|coro| + + Creates a :class:`GuildSoundboardSound` for the guild. + + You must have :attr:`~Permissions.create_guild_expressions` permission to + do this. + + .. versionadded:: 2.10 + + Parameters + ---------- + name: :class:`str` + The sound name. Must be at least 2 characters. + sound: |resource_type| + The sound data. + Only MP3 and Ogg formats are supported. + volume: Optional[:class:`float`] + The sound's volume (from ``0.0`` to ``1.0``). + Defaults to ``1.0``. + emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]] + The sound's emoji, if any. + reason: Optional[:class:`str`] + The reason for creating this sound. Shows up on the audit log. + + Raises + ------ + Forbidden + You are not allowed to create soundboard sounds. + HTTPException + An error occurred creating a soundboard sound. + + Returns + ------- + :class:`GuildSoundboardSound` + The newly created soundboard sound. + """ + sound_data = await utils._assetbytes_to_base64_data(sound) + emoji_name, emoji_id = PartialEmoji._emoji_to_name_id(emoji) + + data = await self._state.http.create_guild_soundboard_sound( + self.id, + name=name, + sound=sound_data, + volume=volume, + emoji_id=emoji_id, + emoji_name=emoji_name, + reason=reason, + ) + return GuildSoundboardSound(data=data, state=self._state, guild_id=self.id) + + async def fetch_soundboard_sound(self, sound_id: int, /) -> GuildSoundboardSound: + """|coro| + + Retrieves a soundboard sound from the guild. + See also :func:`~Guild.fetch_soundboard_sounds`. + + .. note:: + + This method is an API call. For general usage, consider :attr:`soundboard_sounds` instead. + + .. versionadded:: 2.10 + + Raises + ------ + NotFound + A soundboard sound with the provided ID does not exist in the guild. + HTTPException + Retrieving the soundboard sound failed. + + Returns + ------- + :class:`GuildSoundboardSound` + The soundboard sound. + """ + data = await self._state.http.get_guild_soundboard_sound(self.id, sound_id) + return GuildSoundboardSound(data=data, state=self._state, guild_id=self.id) + + async def fetch_soundboard_sounds(self) -> List[GuildSoundboardSound]: + """|coro| + + Retrieves all :class:`GuildSoundboardSound`\\s that the guild has. + + .. note:: + + This method is an API call. For general usage, consider :attr:`soundboard_sounds` instead. + + .. versionadded:: 2.10 + + Raises + ------ + HTTPException + Retrieving the soundboard sounds failed. + + Returns + ------- + List[:class:`GuildSoundboardSound`] + All soundboard sounds that the guild has. + """ + data = await self._state.http.get_guild_soundboard_sounds(self.id) + return [ + GuildSoundboardSound(data=d, state=self._state, guild_id=self.id) for d in data["items"] + ] + PlaceholderID = NewType("PlaceholderID", int) diff --git a/disnake/http.py b/disnake/http.py index 89dfe30724..8304d780d4 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -73,6 +73,7 @@ poll, role, sku, + soundboard, sticker, template, threads, @@ -1654,7 +1655,7 @@ def create_guild_sticker( initial_bytes = file.fp.read(16) try: - mime_type = utils._get_mime_type_for_image(initial_bytes) + mime_type = utils._get_mime_type_for_data(initial_bytes) except ValueError: if initial_bytes.startswith(b"{"): mime_type = "application/json" @@ -2816,6 +2817,124 @@ def get_application_command_permissions( ) return self.request(r) + # Soundboard + + def get_default_soundboard_sounds(self) -> Response[List[soundboard.SoundboardSound]]: + return self.request(Route("GET", "/soundboard-default-sounds")) + + def get_guild_soundboard_sound( + self, guild_id: Snowflake, sound_id: Snowflake + ) -> Response[soundboard.GuildSoundboardSound]: + return self.request( + Route( + "GET", + "/guilds/{guild_id}/soundboard-sounds/{sound_id}", + guild_id=guild_id, + sound_id=sound_id, + ) + ) + + def get_guild_soundboard_sounds( + self, guild_id: Snowflake + ) -> Response[soundboard.ListGuildSoundboardSounds]: + return self.request( + Route( + "GET", + "/guilds/{guild_id}/soundboard-sounds", + guild_id=guild_id, + ) + ) + + def create_guild_soundboard_sound( + self, + guild_id: Snowflake, + *, + name: str, + sound: Optional[str], + volume: Optional[float] = None, + emoji_id: Optional[Snowflake] = None, + emoji_name: Optional[str] = None, + reason: Optional[str] = None, + ) -> Response[soundboard.GuildSoundboardSound]: + payload: Dict[str, Any] = { + "name": name, + "sound": sound, + } + + if volume is not None: + payload["volume"] = volume + if emoji_id is not None: + payload["emoji_id"] = emoji_id + if emoji_name is not None: + payload["emoji_name"] = emoji_name + + return self.request( + Route("POST", "/guilds/{guild_id}/soundboard-sounds", guild_id=guild_id), + json=payload, + reason=reason, + ) + + def edit_guild_soundboard_sound( + self, + guild_id: Snowflake, + sound_id: Snowflake, + *, + reason: Optional[str] = None, + **fields: Any, + ) -> Response[soundboard.GuildSoundboardSound]: + valid_keys = ( + "name", + "volume", + "emoji_id", + "emoji_name", + ) + payload = {k: v for k, v in fields.items() if k in valid_keys} + return self.request( + Route( + "PATCH", + "/guilds/{guild_id}/soundboard-sounds/{sound_id}", + guild_id=guild_id, + sound_id=sound_id, + ), + json=payload, + reason=reason, + ) + + def delete_guild_soundboard_sound( + self, + guild_id: Snowflake, + sound_id: Snowflake, + *, + reason: Optional[str] = None, + ) -> Response[None]: + return self.request( + Route( + "DELETE", + "/guilds/{guild_id}/soundboard-sounds/{sound_id}", + guild_id=guild_id, + sound_id=sound_id, + ), + reason=reason, + ) + + def send_soundboard_sound( + self, + channel_id: Snowflake, + sound_id: Snowflake, + *, + source_guild_id: Optional[Snowflake] = None, + ) -> Response[None]: + payload: Dict[str, Any] = { + "sound_id": sound_id, + } + if source_guild_id is not None: + payload["source_guild_id"] = source_guild_id + + return self.request( + Route("POST", "/channels/{channel_id}/send-soundboard-sound", channel_id=channel_id), + json=payload, + ) + # Misc def get_voice_regions(self) -> Response[List[voice.VoiceRegion]]: diff --git a/disnake/soundboard.py b/disnake/soundboard.py new file mode 100644 index 0000000000..0d1d701dac --- /dev/null +++ b/disnake/soundboard.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from .asset import Asset, AssetMixin +from .mixins import Hashable +from .partial_emoji import PartialEmoji +from .utils import MISSING, _get_as_snowflake, snowflake_time + +if TYPE_CHECKING: + from .emoji import Emoji + from .guild import Guild + from .state import ConnectionState + from .types.soundboard import ( + GuildSoundboardSound as GuildSoundboardSoundPayload, + PartialSoundboardSound as PartialSoundboardSoundPayload, + SoundboardSound as SoundboardSoundPayload, + ) + from .user import User + + +__all__ = ( + "PartialSoundboardSound", + "SoundboardSound", + "GuildSoundboardSound", +) + + +class PartialSoundboardSound(Hashable, AssetMixin): + """Represents a partial soundboard sound. + + Used for sounds in :class:`VoiceChannelEffect`\\s, + and as the base for full :class:`SoundboardSound`/:class:`GuildSoundboardSound` objects. + + .. versionadded:: 2.10 + + .. collapse:: operations + + .. describe:: x == y + + Checks if two soundboard sounds are equal. + + .. describe:: x != y + + Checks if two soundboard sounds are not equal. + + .. describe:: hash(x) + + Returns the soundboard sounds' hash. + + Attributes + ---------- + id: :class:`int` + The sound's ID. + volume: :class:`float` + The sound's volume (from ``0.0`` to ``1.0``). + """ + + __slots__ = ( + "id", + "volume", + ) + + def __init__( + self, + *, + data: PartialSoundboardSoundPayload, + state: Optional[ConnectionState] = None, + ) -> None: + self._state = state + self.id: int = int(data["sound_id"]) + self.volume: float = data["volume"] + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id!r}>" + + @property + def created_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: Returns the sound's creation time in UTC. + Can be ``None`` if this is a default sound. + """ + if self.is_default(): + return None + return snowflake_time(self.id) + + @property + def url(self) -> str: + """:class:`str`: The url for the sound file.""" + return f"{Asset.BASE}/soundboard-sounds/{self.id}" + + def is_default(self) -> bool: + """Whether the sound is a default sound provided by Discord. + + :return type: :class:`bool` + """ + # default sounds have IDs starting from 1, i.e. tiny numbers compared to real snowflakes. + # assume any "snowflake" with a zero timestamp is for a default sound + return (self.id >> 22) == 0 + + +class SoundboardSound(PartialSoundboardSound): + """Represents a soundboard sound. + + .. versionadded:: 2.10 + + .. collapse:: operations + + .. describe:: x == y + + Checks if two soundboard sounds are equal. + + .. describe:: x != y + + Checks if two soundboard sounds are not equal. + + .. describe:: hash(x) + + Returns the soundboard sounds' hash. + + Attributes + ---------- + id: :class:`int` + The sound's ID. + volume: :class:`float` + The sound's volume (from ``0.0`` to ``1.0``). + name: :class:`str` + The sound's name. + emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]] + The sound's emoji, if any. + Due to a Discord limitation, this will have an empty + :attr:`~PartialEmoji.name` if it is a custom :class:`PartialEmoji`. + available: :class:`bool` + Whether this sound is available for use. + """ + + __slots__ = ("name", "emoji", "available") + + _state: ConnectionState + + def __init__( + self, + *, + data: SoundboardSoundPayload, + state: ConnectionState, + ) -> None: + super().__init__(data=data, state=state) + + self.name: str = data["name"] + self.emoji: Optional[Union[Emoji, PartialEmoji]] = self._state._get_emoji_from_fields( + name=data.get("emoji_name"), + id=_get_as_snowflake(data, "emoji_id"), + ) + self.available: bool = data.get("available", True) + + def __repr__(self) -> str: + return f"" + + +class GuildSoundboardSound(SoundboardSound): + """Represents a soundboard sound that belongs to a guild. + + .. versionadded:: 2.10 + + .. collapse:: operations + + .. describe:: x == y + + Checks if two soundboard sounds are equal. + + .. describe:: x != y + + Checks if two soundboard sounds are not equal. + + .. describe:: hash(x) + + Returns the soundboard sounds' hash. + + Attributes + ---------- + id: :class:`int` + The sound's ID. + volume: :class:`float` + The sound's volume (from ``0.0`` to ``1.0``). + name: :class:`str` + The sound's name. + emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]] + The sound's emoji, if any. + Due to a Discord limitation, this will have an empty + :attr:`~PartialEmoji.name` if it is a custom :class:`PartialEmoji`. + guild_id: :class:`int` + The ID of the guild this sound belongs to. + available: :class:`bool` + Whether this sound is available for use. + user: Optional[:class:`User`] + The user that created this sound. This can only be retrieved using + :meth:`Guild.fetch_soundboard_sound`/:meth:`Guild.fetch_soundboard_sounds` while + having the :attr:`~Permissions.create_guild_expressions` or + :attr:`~Permissions.manage_guild_expressions` permission. + """ + + __slots__ = ("guild_id", "user") + + def __init__( + self, + *, + data: GuildSoundboardSoundPayload, + state: ConnectionState, + # `guild_id` isn't sent over REST, so we manually keep track of it + guild_id: int, + ) -> None: + super().__init__(data=data, state=state) + + self.guild_id: int = guild_id + self.user: Optional[User] = ( + state.store_user(user_data) if (user_data := data.get("user")) is not None else None + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + @property + def guild(self) -> Guild: + """:class:`Guild`: The guild that this sound is from.""" + # this will most likely never return None + return self._state._get_guild(self.guild_id) # type: ignore + + async def edit( + self, + *, + name: str = MISSING, + volume: float = MISSING, + emoji: Optional[Union[str, Emoji, PartialEmoji]] = MISSING, + reason: Optional[str] = None, + ) -> GuildSoundboardSound: + """|coro| + + Edits a :class:`GuildSoundboardSound` for the guild. + + You must have :attr:`~Permissions.manage_guild_expressions` permission to + do this. + If this sound was created by you, :attr:`~Permissions.create_guild_expressions` + permission is also sufficient. + + All fields are optional. + + Parameters + ---------- + name: :class:`str` + The sounds's new name. Must be at least 2 characters. + volume: :class:`float` + The sound's new volume (from ``0.0`` to ``1.0``). + emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]] + The sound's new emoji. Can be ``None``. + reason: Optional[:class:`str`] + The reason for editing this sound. Shows up on the audit log. + + Raises + ------ + Forbidden + You are not allowed to edit this soundboard sound. + HTTPException + An error occurred editing the soundboard sound. + + Returns + ------- + :class:`GuildSoundboardSound` + The newly modified soundboard sound. + """ + payload: Dict[str, Any] = {} + + if name is not MISSING: + payload["name"] = name + if volume is not MISSING: + payload["volume"] = volume + if emoji is not MISSING: + emoji_name, emoji_id = PartialEmoji._emoji_to_name_id(emoji) + payload["emoji_name"] = emoji_name + payload["emoji_id"] = emoji_id + + data = await self._state.http.edit_guild_soundboard_sound( + self.guild_id, self.id, reason=reason, **payload + ) + return GuildSoundboardSound(data=data, state=self._state, guild_id=self.guild_id) + + async def delete(self, *, reason: Optional[str] = None) -> None: + """|coro| + + Deletes the :class:`GuildSoundboardSound` from the guild. + + You must have :attr:`~Permissions.manage_guild_expressions` permission to + do this. + If this sound was created by you, :attr:`~Permissions.create_guild_expressions` + permission is also sufficient. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason for deleting this sound. Shows up on the audit log. + + Raises + ------ + Forbidden + You are not allowed to delete this soundboard sound. + HTTPException + An error occurred deleting the soundboard sound. + """ + await self._state.http.delete_guild_soundboard_sound(self.guild_id, self.id, reason=reason) diff --git a/disnake/state.py b/disnake/state.py index bb9b99a874..54facfc626 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -18,6 +18,7 @@ Coroutine, Deque, Dict, + Generic, List, Literal, Optional, @@ -83,6 +84,7 @@ RawVoiceChannelEffectEvent, ) from .role import Role +from .soundboard import GuildSoundboardSound from .stage_instance import StageInstance from .sticker import GuildSticker from .threads import Thread, ThreadMember @@ -106,17 +108,44 @@ from .types.guild import Guild as GuildPayload, UnavailableGuild as UnavailableGuildPayload from .types.interactions import InteractionChannel as InteractionChannelPayload from .types.message import Message as MessagePayload + from .types.soundboard import GuildSoundboardSound as GuildSoundboardSoundPayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.user import User as UserPayload from .types.webhook import Webhook as WebhookPayload from .voice_client import VoiceProtocol - T = TypeVar("T") Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel] PartialChannel = Union[Channel, PartialMessageable] +T = TypeVar("T") -class ChunkRequest: + +class AsyncRequest(Generic[T]): + def __init__(self, guild_id: int, loop: asyncio.AbstractEventLoop) -> None: + self.guild_id: int = guild_id + self.loop: asyncio.AbstractEventLoop = loop + self.waiters: List[asyncio.Future[T]] = [] + + async def wait(self) -> T: + future: asyncio.Future[T] = self.loop.create_future() + self.waiters.append(future) + try: + return await future + finally: + self.waiters.remove(future) + + def get_future(self) -> asyncio.Future[T]: + future: asyncio.Future[T] = self.loop.create_future() + self.waiters.append(future) + return future + + def set_result(self, result: T) -> None: + for future in self.waiters: + if not future.done(): + future.set_result(result) + + +class ChunkRequest(AsyncRequest[List[Member]]): def __init__( self, guild_id: int, @@ -125,13 +154,11 @@ def __init__( *, cache: bool = True, ) -> None: - self.guild_id: int = guild_id + super().__init__(guild_id=guild_id, loop=loop) self.resolver: Callable[[int], Any] = resolver - self.loop: asyncio.AbstractEventLoop = loop self.cache: bool = cache self.nonce: str = os.urandom(16).hex() self.buffer: List[Member] = [] - self.waiters: List[asyncio.Future[List[Member]]] = [] def add_members(self, members: List[Member]) -> None: self.buffer.extend(members) @@ -145,23 +172,8 @@ def add_members(self, members: List[Member]) -> None: if existing is None or existing.joined_at is None: guild._add_member(member) - async def wait(self) -> List[Member]: - future = self.loop.create_future() - self.waiters.append(future) - try: - return await future - finally: - self.waiters.remove(future) - - def get_future(self) -> asyncio.Future[List[Member]]: - future = self.loop.create_future() - self.waiters.append(future) - return future - def done(self) -> None: - for future in self.waiters: - if not future.done(): - future.set_result(self.buffer) + self.set_result(self.buffer) _log = logging.getLogger(__name__) @@ -300,6 +312,7 @@ def clear( self._users: weakref.WeakValueDictionary[int, User] = weakref.WeakValueDictionary() self._emojis: Dict[int, Emoji] = {} self._stickers: Dict[int, GuildSticker] = {} + self._soundboard_sounds: Dict[int, GuildSoundboardSound] = {} self._guilds: Dict[int, Guild] = {} if application_commands: @@ -410,6 +423,15 @@ def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker + def store_soundboard_sound( + self, guild: Guild, data: GuildSoundboardSoundPayload + ) -> GuildSoundboardSound: + sound_id = int(data["sound_id"]) + self._soundboard_sounds[sound_id] = sound = GuildSoundboardSound( + state=self, data=data, guild_id=guild.id + ) + return sound + def store_view(self, view: View, message_id: Optional[int] = None) -> None: self._view_store.add_view(view, message_id) @@ -445,6 +467,9 @@ def _remove_guild(self, guild: Guild) -> None: for sticker in guild.stickers: self._stickers.pop(sticker.id, None) + for sound in guild.soundboard_sounds: + self._soundboard_sounds.pop(sound.id, None) + del guild def _get_global_application_command( @@ -520,6 +545,10 @@ def emojis(self) -> List[Emoji]: def stickers(self) -> List[GuildSticker]: return list(self._stickers.values()) + @property + def soundboard_sounds(self) -> List[GuildSoundboardSound]: + return list(self._soundboard_sounds.values()) + def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]: # the keys of self._emojis are ints return self._emojis.get(emoji_id) # type: ignore @@ -528,6 +557,10 @@ def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]: # the keys of self._stickers are ints return self._stickers.get(sticker_id) # type: ignore + def get_soundboard_sound(self, sound_id: Optional[int]) -> Optional[GuildSoundboardSound]: + # the keys of self._soundboard_sounds are ints + return self._soundboard_sounds.get(sound_id) # type: ignore + @property def private_channels(self) -> List[PrivateChannel]: return list(self._private_channels.values()) @@ -1390,8 +1423,8 @@ def parse_guild_stickers_update(self, data: gateway.GuildStickersUpdateEvent) -> return before_stickers = guild.stickers - for emoji in before_stickers: - self._stickers.pop(emoji.id, None) + for sticker in before_stickers: + self._stickers.pop(sticker.id, None) guild.stickers = tuple(self.store_sticker(guild, d) for d in data["stickers"]) self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers) @@ -1974,6 +2007,107 @@ def parse_entitlement_delete(self, data: gateway.EntitlementDelete) -> None: entitlement = Entitlement(data=data, state=self) self.dispatch("entitlement_delete", entitlement) + def parse_guild_soundboard_sound_create(self, data: gateway.GuildSoundboardSoundCreate) -> None: + guild_id = utils._get_as_snowflake(data, "guild_id") + guild = self._get_guild(guild_id) + if guild is None: + _log.debug( + "GUILD_SOUNDBOARD_SOUND_CREATE referencing unknown guild ID: %s. Discarding.", + guild_id, + ) + return + + sound = self.store_soundboard_sound(guild, data) + + # since both single-target `SOUND_CREATE`/`_UPDATE`/`_DELETE`s and a generic `SOUNDS_UPDATE` + # exist, turn these events into synthetic `SOUNDS_UPDATE`s + self._handle_soundboard_update( + guild, + # append new sound + guild.soundboard_sounds + (sound,), + ) + + def parse_guild_soundboard_sound_update(self, data: gateway.GuildSoundboardSoundUpdate) -> None: + guild_id = utils._get_as_snowflake(data, "guild_id") + guild = self._get_guild(guild_id) + if guild is None: + _log.debug( + "GUILD_SOUNDBOARD_SOUND_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) + return + + sound_id = int(data["sound_id"]) + sound = self.get_soundboard_sound(sound_id) + if sound is None: + _log.debug( + "GUILD_SOUNDBOARD_SOUND_UPDATE referencing unknown sound ID: %s. Discarding.", + sound_id, + ) + return + + self._soundboard_sounds.pop(sound.id, None) + new_sound = self.store_soundboard_sound(guild, data) + + self._handle_soundboard_update( + guild, + # replace sound in tuple at same position + tuple((new_sound if s.id == sound.id else s) for s in guild.soundboard_sounds), + ) + + def parse_guild_soundboard_sound_delete(self, data: gateway.GuildSoundboardSoundDelete) -> None: + guild = self._get_guild(int(data["guild_id"])) + if guild is None: + _log.debug( + "GUILD_SOUNDBOARD_SOUND_DELETE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + return + + sound_id = int(data["sound_id"]) + sound = self.get_soundboard_sound(sound_id) + if sound is None: + _log.debug( + "GUILD_SOUNDBOARD_SOUND_UPDATE referencing unknown sound ID: %s. Discarding.", + sound_id, + ) + return + + self._soundboard_sounds.pop(sound.id, None) + + self._handle_soundboard_update( + guild, + # remove sound from tuple + tuple(s for s in guild.soundboard_sounds if s.id != sound.id), + ) + + def parse_guild_soundboard_sounds_update( + self, data: gateway.GuildSoundboardSoundsUpdate + ) -> None: + guild = self._get_guild(int(data["guild_id"])) + if guild is None: + _log.debug( + "GUILD_SOUNDBOARD_SOUNDS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + return + + for sound in guild.soundboard_sounds: + self._soundboard_sounds.pop(sound.id, None) + + self._handle_soundboard_update( + guild, + tuple(self.store_soundboard_sound(guild, d) for d in data["soundboard_sounds"]), + ) + + def _handle_soundboard_update( + self, guild: Guild, new_sounds: Tuple[GuildSoundboardSound, ...] + ) -> None: + before_sounds = guild.soundboard_sounds + guild.soundboard_sounds = new_sounds + + self.dispatch("guild_soundboard_sounds_update", guild, before_sounds, new_sounds) + def _get_reaction_user( self, channel: MessageableChannel, user_id: int ) -> Optional[Union[User, Member]]: diff --git a/disnake/types/audit_log.py b/disnake/types/audit_log.py index f9640b3ad9..f5306a7d91 100644 --- a/disnake/types/audit_log.py +++ b/disnake/types/audit_log.py @@ -79,6 +79,9 @@ 111, 112, 121, + 130, + 131, + 132, 140, 141, 142, diff --git a/disnake/types/gateway.py b/disnake/types/gateway.py index 7f86cd5959..0aabbbf47c 100644 --- a/disnake/types/gateway.py +++ b/disnake/types/gateway.py @@ -22,6 +22,7 @@ from .message import Message from .role import Role from .snowflake import Snowflake, SnowflakeList +from .soundboard import GuildSoundboardSound from .sticker import GuildSticker from .threads import Thread, ThreadMember, ThreadMemberWithPresence, ThreadType from .user import AvatarDecorationData, User @@ -666,3 +667,23 @@ class AutoModerationActionExecutionEvent(TypedDict): # https://discord.com/developers/docs/monetization/entitlements#deleted-entitlement EntitlementDelete = Entitlement + + +# https://discord.com/developers/docs/topics/gateway-events#guild-soundboard-sound-create +GuildSoundboardSoundCreate = GuildSoundboardSound + + +# https://discord.com/developers/docs/topics/gateway-events#guild-soundboard-sound-update +GuildSoundboardSoundUpdate = GuildSoundboardSound + + +# https://discord.com/developers/docs/topics/gateway-events#guild-soundboard-sound-delete +class GuildSoundboardSoundDelete(TypedDict): + guild_id: Snowflake + sound_id: Snowflake + + +# https://discord.com/developers/docs/topics/gateway-events#guild-soundboard-sounds-update +class GuildSoundboardSoundsUpdate(TypedDict): + guild_id: Snowflake + soundboard_sounds: List[GuildSoundboardSound] diff --git a/disnake/types/guild.py b/disnake/types/guild.py index 76d8d2f6b5..6b53b8e1f3 100644 --- a/disnake/types/guild.py +++ b/disnake/types/guild.py @@ -11,6 +11,7 @@ from .member import Member from .role import CreateRole, Role from .snowflake import Snowflake +from .soundboard import GuildSoundboardSound from .sticker import GuildSticker from .threads import Thread from .user import User @@ -61,6 +62,7 @@ class UnavailableGuild(TypedDict): "MEMBER_PROFILES", # not sure what this does, if anything "MEMBER_VERIFICATION_GATE_ENABLED", "MORE_EMOJI", + "MORE_SOUNDBOARD", "MORE_STICKERS", "NEWS", "NEW_THREAD_PERMISSIONS", # deprecated @@ -73,6 +75,7 @@ class UnavailableGuild(TypedDict): "ROLE_SUBSCRIPTIONS_AVAILABLE_FOR_PURCHASE", "ROLE_SUBSCRIPTIONS_ENABLED", "SEVEN_DAY_THREAD_ARCHIVE", # deprecated + "SOUNDBOARD", "TEXT_IN_VOICE_ENABLED", # deprecated "THREADS_ENABLED", # deprecated "THREE_DAY_THREAD_ARCHIVE", # deprecated @@ -147,6 +150,7 @@ class Guild(_BaseGuildPreview): presences: NotRequired[List[PartialPresenceUpdate]] stage_instances: NotRequired[List[StageInstance]] guild_scheduled_events: NotRequired[List[GuildScheduledEvent]] + soundboard_sounds: NotRequired[List[GuildSoundboardSound]] class InviteGuild(Guild, total=False): diff --git a/disnake/types/soundboard.py b/disnake/types/soundboard.py new file mode 100644 index 0000000000..333d7033b3 --- /dev/null +++ b/disnake/types/soundboard.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from typing import List, Optional, TypedDict + +from typing_extensions import NotRequired + +from .snowflake import Snowflake +from .user import User + + +class PartialSoundboardSound(TypedDict): + sound_id: Snowflake + volume: float + + +class SoundboardSound(PartialSoundboardSound): + name: str + emoji_id: Optional[Snowflake] + emoji_name: Optional[str] + available: bool + + +class GuildSoundboardSound(SoundboardSound): + guild_id: NotRequired[Snowflake] + user: NotRequired[User] # only available via REST, given appropriate permissions + + +class ListGuildSoundboardSounds(TypedDict): + items: List[GuildSoundboardSound] diff --git a/disnake/types/voice.py b/disnake/types/voice.py index 4ad9bc36b9..3b59312ddf 100644 --- a/disnake/types/voice.py +++ b/disnake/types/voice.py @@ -66,3 +66,5 @@ class VoiceChannelEffect(TypedDict, total=False): emoji: Optional[PartialEmoji] animation_type: Optional[VoiceChannelEffectAnimationType] animation_id: int + sound_id: Snowflake + sound_volume: float diff --git a/disnake/utils.py b/disnake/utils.py index 1cbd012344..a99b5558e4 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -508,10 +508,12 @@ def _maybe_cast(value: V, converter: Callable[[V], T], default: T = None) -> Opt "image/jpeg": ".jpg", "image/gif": ".gif", "image/webp": ".webp", + "audio/mpeg": ".mp3", + "audio/ogg": ".ogg", } -def _get_mime_type_for_image(data: _BytesLike) -> str: +def _get_mime_type_for_data(data: _BytesLike) -> str: if data[0:8] == b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": return "image/png" elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): @@ -520,20 +522,26 @@ def _get_mime_type_for_image(data: _BytesLike) -> str: return "image/gif" elif data[0:4] == b"RIFF" and data[8:12] == b"WEBP": return "image/webp" + elif data[0:3] == b"ID3" or data[0:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"): + # n.b. this doesn't support the unofficial MPEG-2.5 frame header (which starts with 0xFFEx). + # Discord also doesn't accept it. + return "audio/mpeg" + elif data[0:4] == b"OggS": + return "audio/ogg" else: - raise ValueError("Unsupported image type given") + raise ValueError("Unsupported file type provided") def _bytes_to_base64_data(data: _BytesLike) -> str: fmt = "data:{mime};base64,{data}" - mime = _get_mime_type_for_image(data) + mime = _get_mime_type_for_data(data) b64 = b64encode(data).decode("ascii") return fmt.format(mime=mime, data=b64) -def _get_extension_for_image(data: _BytesLike) -> Optional[str]: +def _get_extension_for_data(data: _BytesLike) -> Optional[str]: try: - mime_type = _get_mime_type_for_image(data) + mime_type = _get_mime_type_for_data(data) except ValueError: return None return _mime_type_extensions.get(mime_type) diff --git a/docs/api/audit_logs.rst b/docs/api/audit_logs.rst index 29c03d52ab..448a3ffcf5 100644 --- a/docs/api/audit_logs.rst +++ b/docs/api/audit_logs.rst @@ -525,11 +525,15 @@ AuditLogDiff .. attribute:: emoji - The name of the sticker's or role's emoji being changed. + For stickers or roles, the emoji name of the target being changed + (this will be of type :class:`str`). - See also :attr:`GuildSticker.emoji` or :attr:`Role.emoji`. + For soundboard sounds, the associated emoji of the target being changed + (this will be of type Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]]). - :type: :class:`str` + See also :attr:`GuildSticker.emoji`, :attr:`Role.emoji`, or :attr:`GuildSoundboardSound.emoji`. + + :type: Union[:class:`str`, Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]]] .. attribute:: description @@ -716,6 +720,12 @@ AuditLogDiff :type: Optional[:class:`ThreadSortOrder`] + .. attribute:: volume + + The volume of a soundboard sound being changed. + + :type: :class:`float` + Enumerations ------------ @@ -1585,6 +1595,44 @@ AuditLogAction .. versionchanged:: 2.6 Added support for :class:`PartialIntegration`, and added ``integration`` to :attr:`~AuditLogEntry.extra`. + .. attribute:: soundboard_sound_create + + A soundboard sound was created. + + Possible attributes for :class:`AuditLogDiff`: + + - :attr:`~AuditLogDiff.id` + - :attr:`~AuditLogDiff.name` + - :attr:`~AuditLogDiff.volume` + - :attr:`~AuditLogDiff.emoji` + + .. versionadded:: 2.10 + + .. attribute:: soundboard_sound_update + + A soundboard sound was updated. + + Possible attributes for :class:`AuditLogDiff`: + + - :attr:`~AuditLogDiff.name` + - :attr:`~AuditLogDiff.volume` + - :attr:`~AuditLogDiff.emoji` + + .. versionadded:: 2.10 + + .. attribute:: soundboard_sound_delete + + A soundboard sound was deleted. + + Possible attributes for :class:`AuditLogDiff`: + + - :attr:`~AuditLogDiff.id` + - :attr:`~AuditLogDiff.name` + - :attr:`~AuditLogDiff.volume` + - :attr:`~AuditLogDiff.emoji` + + .. versionadded:: 2.10 + .. attribute:: automod_rule_create An auto moderation rule was created. diff --git a/docs/api/events.rst b/docs/api/events.rst index c5683e456c..bd4f16652a 100644 --- a/docs/api/events.rst +++ b/docs/api/events.rst @@ -670,7 +670,7 @@ Emojis Called when a :class:`Guild` adds or removes :class:`Emoji`. - This requires :attr:`Intents.emojis_and_stickers` to be enabled. + This requires :attr:`Intents.expressions` to be enabled. :param guild: The guild who got their emojis updated. :type guild: :class:`Guild` @@ -978,6 +978,24 @@ Scheduled Events :param payload: The raw event payload data. :type payload: :class:`RawGuildScheduledEventUserActionEvent` +Soundboard +++++++++++ + +.. function:: on_guild_soundboard_sounds_update(guild, before, after) + + Called when a :class:`Guild` updates its soundboard sounds. + + This requires :attr:`Intents.expressions` to be enabled. + + .. versionadded:: 2.10 + + :param guild: The guild who got their soundboard sounds updated. + :type guild: :class:`Guild` + :param before: A list of soundboard sounds before the update. + :type before: Sequence[:class:`GuildSoundboardSound`] + :param after: A list of soundboard sounds after the update. + :type after: Sequence[:class:`GuildSoundboardSound`] + Stage Instances +++++++++++++++ @@ -1014,7 +1032,7 @@ Stickers Called when a :class:`Guild` updates its stickers. - This requires :attr:`Intents.emojis_and_stickers` to be enabled. + This requires :attr:`Intents.expressions` to be enabled. .. versionadded:: 2.0 diff --git a/docs/api/index.rst b/docs/api/index.rst index 8b1dea42da..c29ba29e6f 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -116,6 +116,7 @@ Documents permissions roles skus + soundboard stage_instances stickers users diff --git a/docs/api/soundboard.rst b/docs/api/soundboard.rst new file mode 100644 index 0000000000..28f5a3e916 --- /dev/null +++ b/docs/api/soundboard.rst @@ -0,0 +1,45 @@ +.. SPDX-License-Identifier: MIT + +.. currentmodule:: disnake + +Soundboard +========== + +This section documents everything related to Discord +:ddocs:`soundboards `. + +Discord Models +-------------- + +PartialSoundboardSound +~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: PartialSoundboardSound + +.. autoclass:: PartialSoundboardSound() + :members: + :inherited-members: + +SoundboardSound +~~~~~~~~~~~~~~~ + +.. attributetable:: SoundboardSound + +.. autoclass:: SoundboardSound() + :members: + :inherited-members: + +GuildSoundboardSound +~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: GuildSoundboardSound + +.. autoclass:: GuildSoundboardSound() + :members: + :inherited-members: + + +Events +------ + +- :func:`on_guild_soundboard_sounds_update(guild, before, after) ` diff --git a/docs/ext/commands/api/converters.rst b/docs/ext/commands/api/converters.rst index 012dca1261..d2cec39784 100644 --- a/docs/ext/commands/api/converters.rst +++ b/docs/ext/commands/api/converters.rst @@ -78,6 +78,9 @@ Classes .. autoclass:: GuildStickerConverter :members: +.. autoclass:: GuildSoundboardSoundConverter + :members: + .. autoclass:: PermissionsConverter :members: diff --git a/docs/ext/commands/api/exceptions.rst b/docs/ext/commands/api/exceptions.rst index bde1d7d446..6f6549be6e 100644 --- a/docs/ext/commands/api/exceptions.rst +++ b/docs/ext/commands/api/exceptions.rst @@ -120,6 +120,9 @@ Exceptions .. autoexception:: GuildStickerNotFound :members: +.. autoexception:: GuildSoundboardSoundNotFound + :members: + .. autoexception:: GuildScheduledEventNotFound :members: @@ -213,6 +216,7 @@ Exception Hierarchy - :exc:`EmojiNotFound` - :exc:`PartialEmojiConversionFailure` - :exc:`GuildStickerNotFound` + - :exc:`GuildSoundboardSoundNotFound` - :exc:`GuildScheduledEventNotFound` - :exc:`BadBoolArgument` - :exc:`LargeIntConversionFailure` diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index 18024647ca..3fd5ca0be0 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -399,6 +399,7 @@ A lot of Discord models work out of the gate as a parameter: - :class:`GuildSticker` (since v2.0) - :class:`Permissions` (since v2.3) - :class:`GuildScheduledEvent` (since v2.5) +- :class:`GuildSoundboardSound` (since v2.10) Having any of these set as the converter will intelligently convert the argument to the appropriate target type you specify. @@ -455,6 +456,8 @@ converter is given below: +------------------------------+--------------------------------------------------------+ | :class:`GuildScheduledEvent` | :class:`~ext.commands.GuildScheduledEventConverter` | +------------------------------+--------------------------------------------------------+ +| :class:`GuildSoundboardSound`| :class:`~ext.commands.GuildSoundboardSoundConverter` | ++------------------------------+--------------------------------------------------------+ By providing the converter it allows us to use them as building blocks for another converter: diff --git a/tests/test_utils.py b/tests/test_utils.py index 75e3944151..ba4dca5e15 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -267,17 +267,20 @@ def test_maybe_cast() -> None: (b"\x47\x49\x46\x38\x37\x61", "image/gif", ".gif"), (b"\x47\x49\x46\x38\x39\x61", "image/gif", ".gif"), (b"RIFFxxxxWEBP", "image/webp", ".webp"), + (b"ID3", "audio/mpeg", ".mp3"), + (b"\xFF\xF3", "audio/mpeg", ".mp3"), + (b"OggS", "audio/ogg", ".ogg"), ], ) def test_mime_type_valid(data, expected_mime, expected_ext) -> None: for d in (data, data + b"\xFF"): - assert utils._get_mime_type_for_image(d) == expected_mime - assert utils._get_extension_for_image(d) == expected_ext + assert utils._get_mime_type_for_data(d) == expected_mime + assert utils._get_extension_for_data(d) == expected_ext prefixed = b"\xFF" + data - with pytest.raises(ValueError, match=r"Unsupported image type given"): - utils._get_mime_type_for_image(prefixed) - assert utils._get_extension_for_image(prefixed) is None + with pytest.raises(ValueError, match=r"Unsupported file type provided"): + utils._get_mime_type_for_data(prefixed) + assert utils._get_extension_for_data(prefixed) is None @pytest.mark.parametrize( @@ -291,9 +294,9 @@ def test_mime_type_valid(data, expected_mime, expected_ext) -> None: ], ) def test_mime_type_invalid(data) -> None: - with pytest.raises(ValueError, match=r"Unsupported image type given"): - utils._get_mime_type_for_image(data) - assert utils._get_extension_for_image(data) is None + with pytest.raises(ValueError, match=r"Unsupported file type provided"): + utils._get_mime_type_for_data(data) + assert utils._get_extension_for_data(data) is None @pytest.mark.asyncio From 4f6a371e794b42f4d6d2b327ccdeaa868297eab9 Mon Sep 17 00:00:00 2001 From: vi <8530778+shiftinv@users.noreply.github.com> Date: Sat, 28 Dec 2024 16:39:42 +0100 Subject: [PATCH 3/4] fix(typing): improve view type inference of ui decorators (#1190) --- changelog/1190.feature.rst | 1 + disnake/ui/button.py | 32 +++++++++------------ disnake/ui/item.py | 31 +++++++-------------- disnake/ui/select/base.py | 20 +++++-------- disnake/ui/select/channel.py | 22 +++++++-------- disnake/ui/select/mentionable.py | 24 ++++++++-------- disnake/ui/select/role.py | 22 +++++++-------- disnake/ui/select/string.py | 23 ++++++++------- disnake/ui/select/user.py | 22 +++++++-------- disnake/ui/view.py | 8 +++--- pyproject.toml | 3 +- tests/ui/test_decorators.py | 48 +++++++++++++------------------- 12 files changed, 113 insertions(+), 143 deletions(-) create mode 100644 changelog/1190.feature.rst diff --git a/changelog/1190.feature.rst b/changelog/1190.feature.rst new file mode 100644 index 0000000000..6fd323a472 --- /dev/null +++ b/changelog/1190.feature.rst @@ -0,0 +1 @@ +The ``cls`` parameter of UI component decorators (such as :func:`ui.button`) now accepts any matching callable, in addition to item subclasses. diff --git a/disnake/ui/button.py b/disnake/ui/button.py index 9995013ebb..bfcccb663f 100644 --- a/disnake/ui/button.py +++ b/disnake/ui/button.py @@ -10,10 +10,8 @@ Callable, Optional, Tuple, - Type, TypeVar, Union, - get_origin, overload, ) @@ -21,7 +19,7 @@ from ..enums import ButtonStyle, ComponentType from ..partial_emoji import PartialEmoji, _EmojiTag from ..utils import MISSING -from .item import DecoratedItem, Item, ItemShape +from .item import DecoratedItem, Item __all__ = ( "Button", @@ -263,20 +261,20 @@ def button( style: ButtonStyle = ButtonStyle.secondary, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[Button[V_co]]], DecoratedItem[Button[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, Button[V_co]]], DecoratedItem[Button[V_co]]]: ... @overload def button( - cls: Type[ItemShape[B_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[B_co]], DecoratedItem[B_co]]: + cls: Callable[P, B_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, B_co]], DecoratedItem[B_co]]: ... def button( - cls: Type[ItemShape[B_co, ...]] = Button[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[B_co]], DecoratedItem[B_co]]: + cls: Callable[..., B_co] = Button[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, B_co]], DecoratedItem[B_co]]: """A decorator that attaches a button to a component. The function being decorated should have three parameters, ``self`` representing @@ -293,13 +291,12 @@ def button( Parameters ---------- - cls: Type[:class:`Button`] - The button subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`Button`] + A callable (may be a :class:`Button` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. .. versionadded:: 2.6 - label: Optional[:class:`str`] The label of the button, if any. custom_id: Optional[:class:`str`] @@ -319,13 +316,10 @@ def button( For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ - if (origin := get_origin(cls)) is not None: - cls = origin - - if not isinstance(cls, type) or not issubclass(cls, Button): - raise TypeError(f"cls argument must be a subclass of Button, got {cls!r}") + if not callable(cls): + raise TypeError("cls argument must be callable") - def decorator(func: ItemCallbackType[B_co]) -> DecoratedItem[B_co]: + def decorator(func: ItemCallbackType[V_co, B_co]) -> DecoratedItem[B_co]: if not asyncio.iscoroutinefunction(func): raise TypeError("button function must be a coroutine function") diff --git a/disnake/ui/item.py b/disnake/ui/item.py index c4d29c6417..284e839378 100644 --- a/disnake/ui/item.py +++ b/disnake/ui/item.py @@ -12,17 +12,18 @@ Optional, Protocol, Tuple, + Type, TypeVar, overload, ) __all__ = ("Item", "WrappedComponent") -ItemT = TypeVar("ItemT", bound="Item") +I = TypeVar("I", bound="Item[Any]") V_co = TypeVar("V_co", bound="Optional[View]", covariant=True) if TYPE_CHECKING: - from typing_extensions import ParamSpec, Self + from typing_extensions import Self from ..client import Client from ..components import NestedComponent @@ -31,7 +32,7 @@ from ..types.components import Component as ComponentPayload from .view import View - ItemCallbackType = Callable[[Any, ItemT, MessageInteraction], Coroutine[Any, Any, Any]] + ItemCallbackType = Callable[[V_co, I, MessageInteraction], Coroutine[Any, Any, Any]] else: ParamSpec = TypeVar @@ -160,29 +161,17 @@ async def callback(self, interaction: MessageInteraction[ClientT], /) -> None: pass -I_co = TypeVar("I_co", bound=Item, covariant=True) +SelfViewT = TypeVar("SelfViewT", bound="Optional[View]") -# while the decorators don't actually return a descriptor that matches this protocol, +# While the decorators don't actually return a descriptor that matches this protocol, # this protocol ensures that type checkers don't complain about statements like `self.button.disabled = True`, -# which work as `View.__init__` replaces the handler with the item -class DecoratedItem(Protocol[I_co]): +# which work as `View.__init__` replaces the handler with the item. +class DecoratedItem(Protocol[I]): @overload - def __get__(self, obj: None, objtype: Any) -> ItemCallbackType: + def __get__(self, obj: None, objtype: Type[SelfViewT]) -> ItemCallbackType[SelfViewT, I]: ... @overload - def __get__(self, obj: Any, objtype: Any) -> I_co: - ... - - -T_co = TypeVar("T_co", covariant=True) -P = ParamSpec("P") - - -class ItemShape(Protocol[T_co, P]): - def __new__(cls) -> T_co: - ... - - def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: + def __get__(self, obj: Any, objtype: Any) -> I: ... diff --git a/disnake/ui/select/base.py b/disnake/ui/select/base.py index 912a24ba1f..10cae4f4c9 100644 --- a/disnake/ui/select/base.py +++ b/disnake/ui/select/base.py @@ -7,7 +7,6 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, - Any, Callable, ClassVar, Generic, @@ -19,14 +18,13 @@ Type, TypeVar, Union, - get_origin, ) from ...components import AnySelectMenu, SelectDefaultValue from ...enums import ComponentType, SelectDefaultValueType from ...object import Object from ...utils import MISSING, humanize_list -from ..item import DecoratedItem, Item, ItemShape +from ..item import DecoratedItem, Item __all__ = ("BaseSelect",) @@ -239,24 +237,20 @@ def _transform_default_values( def _create_decorator( - cls: Type[ItemShape[S_co, P]], - # only for input validation - base_cls: Type[BaseSelect[Any, Any, Any]], + # FIXME(3.0): rename `cls` parameter to more closely represent any callable argument type + cls: Callable[P, S_co], /, *args: P.args, **kwargs: P.kwargs, -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: if args: # the `*args` def above is just to satisfy the typechecker raise RuntimeError("expected no *args") - if (origin := get_origin(cls)) is not None: - cls = origin + if not callable(cls): + raise TypeError("cls argument must be callable") - if not isinstance(cls, type) or not issubclass(cls, base_cls): - raise TypeError(f"cls argument must be a subclass of {base_cls.__name__}, got {cls!r}") - - def decorator(func: ItemCallbackType[S_co]) -> DecoratedItem[S_co]: + def decorator(func: ItemCallbackType[V_co, S_co]) -> DecoratedItem[S_co]: if not asyncio.iscoroutinefunction(func): raise TypeError("select function must be a coroutine function") diff --git a/disnake/ui/select/channel.py b/disnake/ui/select/channel.py index f004308482..f27c7a2107 100644 --- a/disnake/ui/select/channel.py +++ b/disnake/ui/select/channel.py @@ -30,7 +30,7 @@ from typing_extensions import Self from ...abc import AnyChannel - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -197,20 +197,20 @@ def channel_select( channel_types: Optional[List[ChannelType]] = None, default_values: Optional[Sequence[SelectDefaultValueInputType[AnyChannel]]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[ChannelSelect[V_co]]], DecoratedItem[ChannelSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, ChannelSelect[V_co]]], DecoratedItem[ChannelSelect[V_co]]]: ... @overload def channel_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def channel_select( - cls: Type[ItemShape[S_co, ...]] = ChannelSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = ChannelSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a channel select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -224,10 +224,10 @@ def channel_select( Parameters ---------- - cls: Type[:class:`ChannelSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`ChannelSelect`] + A callable (may be a :class:`ChannelSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -256,4 +256,4 @@ def channel_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, ChannelSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/select/mentionable.py b/disnake/ui/select/mentionable.py index e98dfb29c9..1cc0be5b8a 100644 --- a/disnake/ui/select/mentionable.py +++ b/disnake/ui/select/mentionable.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -174,20 +174,22 @@ def mentionable_select( Sequence[SelectDefaultValueMultiInputType[Union[User, Member, Role]]] ] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[MentionableSelect[V_co]]], DecoratedItem[MentionableSelect[V_co]]]: +) -> Callable[ + [ItemCallbackType[V_co, MentionableSelect[V_co]]], DecoratedItem[MentionableSelect[V_co]] +]: ... @overload def mentionable_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def mentionable_select( - cls: Type[ItemShape[S_co, ...]] = MentionableSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = MentionableSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a mentionable (user/member/role) select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -201,10 +203,10 @@ def mentionable_select( Parameters ---------- - cls: Type[:class:`MentionableSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`MentionableSelect`] + A callable (may be a :class:`MentionableSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -232,4 +234,4 @@ def mentionable_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, MentionableSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/select/role.py b/disnake/ui/select/role.py index 4cb886168f..439749a136 100644 --- a/disnake/ui/select/role.py +++ b/disnake/ui/select/role.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -161,20 +161,20 @@ def role_select( disabled: bool = False, default_values: Optional[Sequence[SelectDefaultValueInputType[Role]]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[RoleSelect[V_co]]], DecoratedItem[RoleSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, RoleSelect[V_co]]], DecoratedItem[RoleSelect[V_co]]]: ... @overload def role_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def role_select( - cls: Type[ItemShape[S_co, ...]] = RoleSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = RoleSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a role select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -188,10 +188,10 @@ def role_select( Parameters ---------- - cls: Type[:class:`RoleSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`RoleSelect`] + A callable (may be a :class:`RoleSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -217,4 +217,4 @@ def role_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, RoleSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/select/string.py b/disnake/ui/select/string.py index 3b12d80388..b336dfa388 100644 --- a/disnake/ui/select/string.py +++ b/disnake/ui/select/string.py @@ -29,7 +29,7 @@ from ...emoji import Emoji from ...partial_emoji import PartialEmoji - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -265,20 +265,20 @@ def string_select( options: SelectOptionInput = ..., disabled: bool = False, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[StringSelect[V_co]]], DecoratedItem[StringSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, StringSelect[V_co]]], DecoratedItem[StringSelect[V_co]]]: ... @overload def string_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def string_select( - cls: Type[ItemShape[S_co, ...]] = StringSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = StringSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a string select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -293,13 +293,12 @@ def string_select( Parameters ---------- - cls: Type[:class:`StringSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`StringSelect`] + A callable (may be a :class:`StringSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. .. versionadded:: 2.6 - placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -329,7 +328,7 @@ def string_select( disabled: :class:`bool` Whether the select is disabled. Defaults to ``False``. """ - return _create_decorator(cls, StringSelect, **kwargs) + return _create_decorator(cls, **kwargs) select = string_select # backwards compatibility diff --git a/disnake/ui/select/user.py b/disnake/ui/select/user.py index 9ab9b803ce..2dd20d40f6 100644 --- a/disnake/ui/select/user.py +++ b/disnake/ui/select/user.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -163,20 +163,20 @@ def user_select( disabled: bool = False, default_values: Optional[Sequence[SelectDefaultValueInputType[Union[User, Member]]]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[UserSelect[V_co]]], DecoratedItem[UserSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, UserSelect[V_co]]], DecoratedItem[UserSelect[V_co]]]: ... @overload def user_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def user_select( - cls: Type[ItemShape[S_co, ...]] = UserSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = UserSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a user select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -190,10 +190,10 @@ def user_select( Parameters ---------- - cls: Type[:class:`UserSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`UserSelect`] + A callable (may be a :class:`UserSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -219,4 +219,4 @@ def user_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, UserSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/view.py b/disnake/ui/view.py index 71c2965074..ffaa90fa3c 100644 --- a/disnake/ui/view.py +++ b/disnake/ui/view.py @@ -153,10 +153,10 @@ class View: """ __discord_ui_view__: ClassVar[bool] = True - __view_children_items__: ClassVar[List[ItemCallbackType[Item]]] = [] + __view_children_items__: ClassVar[List[ItemCallbackType[Self, Item[Self]]]] = [] def __init_subclass__(cls) -> None: - children: List[ItemCallbackType[Item]] = [] + children: List[ItemCallbackType[Self, Item[Self]]] = [] for base in reversed(cls.__mro__): for member in base.__dict__.values(): if hasattr(member, "__discord_ui_model_type__"): @@ -169,9 +169,9 @@ def __init_subclass__(cls) -> None: def __init__(self, *, timeout: Optional[float] = 180.0) -> None: self.timeout = timeout - self.children: List[Item] = [] + self.children: List[Item[Self]] = [] for func in self.__view_children_items__: - item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) + item: Item[Self] = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) item.callback = partial(func, self, item) item._view = self setattr(self, func.__name__, item) diff --git a/pyproject.toml b/pyproject.toml index ed9467bfec..777507fb7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,6 +213,8 @@ ignore = [ "S311", # insecure RNG usage, we don't use these for security-related things "PLE0237", # pyright seems to catch this already + "E741", # ambiguous variable names + # temporary disables, to fix later "D205", # blank line required between summary and description "D401", # first line of docstring should be in imperative mood @@ -248,7 +250,6 @@ ignore = [ "T201", # print found, printing is okay in examples ] "examples/basic_voice.py" = ["S104"] # possible binding to all interfaces -"examples/views/tic_tac_toe.py" = ["E741"] # ambigious variable name: `O` [tool.ruff.lint.isort] combine-as-imports = true diff --git a/tests/ui/test_decorators.py b/tests/ui/test_decorators.py index e9c3680873..7dbe6aa488 100644 --- a/tests/ui/test_decorators.py +++ b/tests/ui/test_decorators.py @@ -9,12 +9,15 @@ from disnake import ui from disnake.ui.button import V_co -T = TypeVar("T", bound=ui.Item) +V = TypeVar("V", bound=ui.View) +I = TypeVar("I", bound=ui.Item) @contextlib.contextmanager -def create_callback(item_type: Type[T]) -> Iterator["ui.item.ItemCallbackType[T]"]: - async def callback(self, item, inter) -> None: +def create_callback( + view_type: Type[V], item_type: Type[I] +) -> Iterator["ui.item.ItemCallbackType[V, I]"]: + async def callback(self: V, item: I, inter) -> None: pytest.fail("callback should not be invoked") yield callback @@ -28,33 +31,36 @@ def __init__(self, *, param: float = 42.0) -> None: pass +class _CustomView(ui.View): + ... + + class TestDecorator: def test_default(self) -> None: - with create_callback(ui.Button[ui.View]) as func: + with create_callback(_CustomView, ui.Button[ui.View]) as func: res = ui.button(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.Button[ui.View]]) + assert_type(res, ui.item.DecoratedItem[ui.Button[_CustomView]]) - assert func.__discord_ui_model_type__ is ui.Button + assert func.__discord_ui_model_type__ is ui.Button[Any] assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} - with create_callback(ui.StringSelect[ui.View]) as func: + with create_callback(_CustomView, ui.StringSelect[ui.View]) as func: res = ui.string_select(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.StringSelect[ui.View]]) + assert_type(res, ui.item.DecoratedItem[ui.StringSelect[_CustomView]]) - assert func.__discord_ui_model_type__ is ui.StringSelect + assert func.__discord_ui_model_type__ is ui.StringSelect[Any] assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} # from here on out we're mostly only testing the button decorator, # as @ui.string_select etc. works identically @pytest.mark.parametrize("cls", [_CustomButton, _CustomButton[Any]]) - def test_cls(self, cls: Type[_CustomButton]) -> None: - with create_callback(cls) as func: + def test_cls(self, cls: Type[_CustomButton[ui.View]]) -> None: + with create_callback(_CustomView, cls) as func: res = ui.button(cls=cls, param=1337)(func) assert_type(res, ui.item.DecoratedItem[cls]) - # should strip to origin type - assert func.__discord_ui_model_type__ is _CustomButton + assert func.__discord_ui_model_type__ is cls assert func.__discord_ui_model_kwargs__ == {"param": 1337} # typing-only check @@ -63,19 +69,3 @@ def _test_typing_cls(self) -> None: cls=_CustomButton, this_should_not_work="h", # type: ignore ) - - @pytest.mark.parametrize( - ("decorator", "invalid_cls"), - [ - (ui.button, ui.StringSelect), - (ui.string_select, ui.Button), - (ui.user_select, ui.Button), - (ui.role_select, ui.Button), - (ui.mentionable_select, ui.Button), - (ui.channel_select, ui.Button), - ], - ) - def test_cls_invalid(self, decorator, invalid_cls) -> None: - for cls in [123, int, invalid_cls]: - with pytest.raises(TypeError, match=r"cls argument must be"): - decorator(cls=cls) From df5e3915c905c014f8343f5913231b25ea0fc06d Mon Sep 17 00:00:00 2001 From: vi <8530778+shiftinv@users.noreply.github.com> Date: Sat, 28 Dec 2024 17:55:06 +0100 Subject: [PATCH 4/4] fix(modal): fix timeout edge cases with `custom_id` reuse and long-running callbacks (#914) --- changelog/914.bugfix.rst | 1 + disnake/ui/modal.py | 104 ++++++++++++++++++++++++--------- examples/interactions/modal.py | 2 +- test_bot/cogs/modals.py | 2 +- 4 files changed, 80 insertions(+), 29 deletions(-) create mode 100644 changelog/914.bugfix.rst diff --git a/changelog/914.bugfix.rst b/changelog/914.bugfix.rst new file mode 100644 index 0000000000..6dd6dcc4bf --- /dev/null +++ b/changelog/914.bugfix.rst @@ -0,0 +1 @@ +Fix :class:`ui.Modal` timeout issues with long-running callbacks, and multiple modals with the same user and ``custom_id``. diff --git a/disnake/ui/modal.py b/disnake/ui/modal.py index adf21ffa9c..7f0192c3b8 100644 --- a/disnake/ui/modal.py +++ b/disnake/ui/modal.py @@ -6,7 +6,8 @@ import os import sys import traceback -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar, Union +from functools import partial +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, TypeVar, Union from ..enums import TextInputStyle from ..utils import MISSING @@ -38,14 +39,32 @@ class Modal: components: |components_type| The components to display in the modal. Up to 5 action rows. custom_id: :class:`str` - The custom ID of the modal. + The custom ID of the modal. This is usually not required. + If not given, then a unique one is generated for you. + + .. note:: + :class:`Modal`\\s are identified based on the user ID that triggered the + modal, and this ``custom_id``. + This can result in collisions when a user opens a modal with the same ``custom_id`` on + two separate devices, for example. + + To avoid such issues, consider not specifying a ``custom_id`` to use an automatically generated one, + or include a unique value in the custom ID (e.g. the original interaction ID). + timeout: :class:`float` The time to wait until the modal is removed from cache, if no interaction is made. Modals without timeouts are not supported, since there's no event for when a modal is closed. Defaults to 600 seconds. """ - __slots__ = ("title", "custom_id", "components", "timeout") + __slots__ = ( + "title", + "custom_id", + "components", + "timeout", + "__remove_callback", + "__timeout_handle", + ) def __init__( self, @@ -67,6 +86,11 @@ def __init__( self.components: List[ActionRow] = rows self.timeout: float = timeout + # function for the modal to remove itself from the store, if any + self.__remove_callback: Optional[Callable[[Modal], None]] = None + # timer handle for the scheduled timeout + self.__timeout_handle: Optional[asyncio.TimerHandle] = None + def __repr__(self) -> str: return ( f" None: except Exception as e: await self.on_error(e, interaction) finally: - # if the interaction was responded to (no matter if in the callback or error handler), - # the modal closed for the user and therefore can be removed from the store - if interaction.response._response_type is not None: - interaction._state._modal_store.remove_modal( - interaction.author.id, interaction.custom_id - ) + if interaction.response._response_type is None: + # If the interaction was not successfully responded to, the modal didn't close for the user. + # Since the timeout was already stopped at this point, restart it. + self._start_listening(self.__remove_callback) + else: + # Otherwise, the modal closed for the user; remove it from the store. + self._stop_listening() + + def _start_listening(self, remove_callback: Optional[Callable[[Modal], None]]) -> None: + self.__remove_callback = remove_callback + + loop = asyncio.get_running_loop() + if self.__timeout_handle is not None: + # shouldn't get here, but handled just in case + self.__timeout_handle.cancel() + + # start timeout + self.__timeout_handle = loop.call_later(self.timeout, self._dispatch_timeout) + + def _stop_listening(self) -> None: + # cancel timeout + if self.__timeout_handle is not None: + self.__timeout_handle.cancel() + self.__timeout_handle = None + + # remove modal from store + if self.__remove_callback is not None: + self.__remove_callback(self) + self.__remove_callback = None + + def _dispatch_timeout(self) -> None: + self._stop_listening() + asyncio.create_task(self.on_timeout(), name=f"disnake-ui-modal-timeout-{self.custom_id}") def dispatch(self, interaction: ModalInteraction) -> None: + # stop the timeout, but don't remove the modal from the store yet in case the + # response fails and the modal stays open + if self.__timeout_handle is not None: + self.__timeout_handle.cancel() + asyncio.create_task( self._scheduled_task(interaction), name=f"disnake-ui-modal-dispatch-{self.custom_id}" ) @@ -232,28 +288,22 @@ def __init__(self, state: ConnectionState) -> None: self._modals: Dict[Tuple[int, str], Modal] = {} def add_modal(self, user_id: int, modal: Modal) -> None: - loop = asyncio.get_running_loop() - self._modals[(user_id, modal.custom_id)] = modal - loop.create_task(self.handle_timeout(user_id, modal.custom_id, modal.timeout)) + key = (user_id, modal.custom_id) - def remove_modal(self, user_id: int, modal_custom_id: str) -> Modal: - return self._modals.pop((user_id, modal_custom_id)) + # if another modal with the same user+custom_id already exists, + # stop its timeout to avoid overlaps/collisions + if (existing := self._modals.get(key)) is not None: + existing._stop_listening() - async def handle_timeout(self, user_id: int, modal_custom_id: str, timeout: float) -> None: - # Waits for the timeout and then removes the modal from cache, this is done just in case - # the user closed the modal, as there isn't an event for that. + # start timeout, store modal + remove_callback = partial(self.remove_modal, user_id) + modal._start_listening(remove_callback) + self._modals[key] = modal - await asyncio.sleep(timeout) - try: - modal = self.remove_modal(user_id, modal_custom_id) - except KeyError: - # The modal has already been removed. - pass - else: - await modal.on_timeout() + def remove_modal(self, user_id: int, modal: Modal) -> None: + self._modals.pop((user_id, modal.custom_id), None) def dispatch(self, interaction: ModalInteraction) -> None: key = (interaction.author.id, interaction.custom_id) - modal = self._modals.get(key) - if modal is not None: + if (modal := self._modals.get(key)) is not None: modal.dispatch(interaction) diff --git a/examples/interactions/modal.py b/examples/interactions/modal.py index f271c82f4c..311b1d7d46 100644 --- a/examples/interactions/modal.py +++ b/examples/interactions/modal.py @@ -43,7 +43,7 @@ def __init__(self) -> None: max_length=1024, ), ] - super().__init__(title="Create Tag", custom_id="create_tag", components=components) + super().__init__(title="Create Tag", components=components) async def callback(self, inter: disnake.ModalInteraction) -> None: tag_name = inter.text_values["name"] diff --git a/test_bot/cogs/modals.py b/test_bot/cogs/modals.py index c5d514a25c..e988c88284 100644 --- a/test_bot/cogs/modals.py +++ b/test_bot/cogs/modals.py @@ -22,7 +22,7 @@ def __init__(self) -> None: style=TextInputStyle.paragraph, ), ] - super().__init__(title="Create Tag", custom_id="create_tag", components=components) + super().__init__(title="Create Tag", components=components) async def callback(self, inter: disnake.ModalInteraction[commands.Bot]) -> None: embed = disnake.Embed(title="Tag Creation")