From 975657a743d0eb056025f21758c24deeb525738e Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Sat, 24 Aug 2024 13:55:50 +0200 Subject: [PATCH] feat(interactions): deserialize `channel` from data (#1012) --- changelog/1012.feature.rst | 1 + disnake/interactions/application_command.py | 26 ++-- disnake/interactions/base.py | 125 +++++++++++--------- disnake/interactions/message.py | 26 ++-- disnake/interactions/modal.py | 17 ++- disnake/state.py | 54 +++++++++ disnake/types/interactions.py | 9 +- tests/interactions/test_base.py | 42 ++++--- 8 files changed, 203 insertions(+), 97 deletions(-) create mode 100644 changelog/1012.feature.rst diff --git a/changelog/1012.feature.rst b/changelog/1012.feature.rst new file mode 100644 index 0000000000..a59debb4e2 --- /dev/null +++ b/changelog/1012.feature.rst @@ -0,0 +1 @@ +:class:`Interaction`\s now always have a proper :attr:`~Interaction.channel` attribute, even when the bot is not part of the guild or cannot access the channel due to other reasons. diff --git a/disnake/interactions/application_command.py b/disnake/interactions/application_command.py index d25f6a2530..46eee43985 100644 --- a/disnake/interactions/application_command.py +++ b/disnake/interactions/application_command.py @@ -58,8 +58,21 @@ class ApplicationCommandInteraction(Interaction[ClientT]): The application ID that the interaction was for. guild_id: Optional[:class:`int`] The guild ID the interaction was sent from. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` @@ -103,7 +116,7 @@ def __init__( ) -> None: super().__init__(data=data, state=state) self.data: ApplicationCommandInteractionData = ApplicationCommandInteractionData( - data=data["data"], state=state, guild_id=self.guild_id + data=data["data"], parent=self ) self.application_command: InvokableApplicationCommand = MISSING self.command_failed: bool = False @@ -200,17 +213,14 @@ def __init__( self, *, data: ApplicationCommandInteractionDataPayload, - state: ConnectionState, - guild_id: Optional[int], + parent: ApplicationCommandInteraction[ClientT], ) -> None: super().__init__(data) self.id: int = int(data["id"]) self.name: str = data["name"] self.type: ApplicationCommandType = try_enum(ApplicationCommandType, data["type"]) - self.resolved = InteractionDataResolved( - data=data.get("resolved", {}), state=state, guild_id=guild_id - ) + self.resolved = InteractionDataResolved(data=data.get("resolved", {}), parent=parent) self.target_id: Optional[int] = utils._get_as_snowflake(data, "target_id") target = self.resolved.get_by_id(self.target_id) self.target: Optional[Union[User, Member, Message]] = target # type: ignore diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index 9543cabc66..7e43874bcd 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -21,10 +21,9 @@ from .. import utils from ..app_commands import OptionChoice -from ..channel import PartialMessageable, _threaded_guild_channel_factory +from ..channel import PartialMessageable from ..entitlement import Entitlement from ..enums import ( - ChannelType, ComponentType, InteractionResponseType, InteractionType, @@ -76,7 +75,6 @@ from ..mentions import AllowedMentions from ..poll import Poll from ..state import ConnectionState - from ..threads import Thread from ..types.components import Modal as ModalPayload from ..types.interactions import ( ApplicationCommandOptionChoice as ApplicationCommandOptionChoicePayload, @@ -90,7 +88,8 @@ from .message import MessageInteraction from .modal import ModalInteraction - InteractionChannel = Union[GuildChannel, Thread, PartialMessageable] + InteractionMessageable = Union[GuildMessageable, PartialMessageable] + InteractionChannel = Union[InteractionMessageable, GuildChannel] AnyBot = Union[Bot, AutoShardedBot] @@ -131,8 +130,21 @@ class Interaction(Generic[ClientT]): .. versionchanged:: 2.5 Changed to :class:`Locale` instead of :class:`str`. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` @@ -159,7 +171,7 @@ class Interaction(Generic[ClientT]): "id", "type", "guild_id", - "channel_id", + "channel", "application_id", "author", "token", @@ -175,7 +187,6 @@ class Interaction(Generic[ClientT]): "_original_response", "_cs_response", "_cs_followup", - "_cs_channel", "_cs_me", "_cs_expires_at", ) @@ -193,8 +204,6 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState) -> None: self.token: str = data["token"] self.version: int = data["version"] self.application_id: int = int(data["application_id"]) - - self.channel_id: int = int(data["channel_id"]) self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") self.locale: Locale = try_enum(Locale, data["locale"]) @@ -208,17 +217,29 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState) -> None: # one of user and member will always exist self.author: Union[User, Member] = MISSING - if self.guild_id and (member := data.get("member")): - guild: Guild = self.guild or Object(id=self.guild_id) # type: ignore + guild_fallback: Optional[Union[Guild, Object]] = None + if self.guild_id: + guild_fallback = self.guild or Object(self.guild_id) + + if guild_fallback and (member := data.get("member")): self.author = ( - isinstance(guild, Guild) - and guild.get_member(int(member["user"]["id"])) - or Member(state=self._state, guild=guild, data=member) + isinstance(guild_fallback, Guild) + and guild_fallback.get_member(int(member["user"]["id"])) + or Member( + state=self._state, + guild=guild_fallback, # type: ignore # may be `Object` + data=member, + ) ) self._permissions = int(member.get("permissions", 0)) elif user := data.get("user"): self.author = self._state.store_user(user) + # TODO: consider making this optional in 3.0 + self.channel: InteractionMessageable = state._get_partial_interaction_channel( + data["channel"], guild_fallback, return_messageable=True + ) + self.entitlements: List[Entitlement] = ( [Entitlement(data=e, state=state) for e in entitlements_data] if (entitlements_data := data.get("entitlements")) @@ -256,24 +277,13 @@ def me(self) -> Union[Member, ClientUser]: return None if self.bot is None else self.bot.user # type: ignore return self.guild.me - @utils.cached_slot_property("_cs_channel") - def channel(self) -> Union[GuildMessageable, PartialMessageable]: - """Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`]: The channel the interaction was sent from. - - Note that due to a Discord limitation, threads that the bot cannot access and DM channels - are not resolved since there is no data to complete them. - These are :class:`PartialMessageable` instead. + @property + def channel_id(self) -> int: + """The channel ID the interaction was sent from. - If you want to compute the interaction author's or bot's permissions in the channel, - consider using :attr:`permissions` or :attr:`app_permissions` instead. + See also :attr:`channel`. """ - guild = self.guild - channel = guild and guild._resolve_channel(self.channel_id) - if channel is None: - # could be a thread channel in a guild, or a DM channel - type = None if self.guild_id is not None else ChannelType.private - return PartialMessageable(state=self._state, id=self.channel_id, type=type) - return channel # type: ignore + return self.channel.id @property def permissions(self) -> Permissions: @@ -1873,8 +1883,7 @@ def __init__( self, *, data: InteractionDataResolvedPayload, - state: ConnectionState, - guild_id: Optional[int], + parent: Interaction[ClientT], ) -> None: data = data or {} super().__init__(data) @@ -1893,6 +1902,9 @@ def __init__( messages = data.get("messages", {}) attachments = data.get("attachments", {}) + state = parent._state + guild_id = parent.guild_id + guild: Optional[Guild] = None # `guild_fallback` is only used in guild contexts, so this `MISSING` value should never be used. # We need to define it anyway to satisfy the typechecker. @@ -1925,36 +1937,35 @@ def __init__( data=role, ) - for str_id, channel in channels.items(): - channel_id = int(str_id) - factory, _ = _threaded_guild_channel_factory(channel["type"]) - if factory: - channel["position"] = 0 # type: ignore - self.channels[channel_id] = ( - guild - and guild.get_channel_or_thread(channel_id) - or factory( - guild=guild_fallback, - state=state, - data=channel, # type: ignore - ) - ) - else: - # TODO: guild_directory is not messageable - self.channels[channel_id] = PartialMessageable( - state=state, id=channel_id, type=try_enum(ChannelType, channel["type"]) - ) + for str_id, channel_data in channels.items(): + self.channels[int(str_id)] = state._get_partial_interaction_channel( + channel_data, guild_fallback + ) for str_id, message in messages.items(): channel_id = int(message["channel_id"]) - channel = cast( - "Optional[MessageableChannel]", - (guild and guild.get_channel(channel_id) or state.get_channel(channel_id)), - ) + channel: Optional[MessageableChannel] = None + + if ( + channel_id == parent.channel.id + # we still want to fall back to state.get_channel when the + # parent channel is a dm/group channel, for now. + # FIXME: remove this once `parent.channel` supports `DMChannel` + and not isinstance(parent.channel, PartialMessageable) + ): + # fast path, this should generally be the case + channel = parent.channel + else: + channel = cast( + "Optional[MessageableChannel]", + (guild and guild.get_channel(channel_id) or state.get_channel(channel_id)), + ) + if channel is None: - # The channel is not part of `resolved.channels`, + # n.b. the message's channel is not sent as part of `resolved.channels`, # so we need to fall back to partials here. channel = PartialMessageable(state=state, id=channel_id, type=None) + self.messages[int(str_id)] = Message(state=state, channel=channel, data=message) for str_id, attachment in attachments.items(): diff --git a/disnake/interactions/message.py b/disnake/interactions/message.py index 4ef51165d5..8ce8c3d3ab 100644 --- a/disnake/interactions/message.py +++ b/disnake/interactions/message.py @@ -47,8 +47,21 @@ class MessageInteraction(Interaction[ClientT]): The token to continue the interaction. These are valid for 15 minutes. guild_id: Optional[:class:`int`] The guild ID the interaction was sent from. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` @@ -85,9 +98,7 @@ class MessageInteraction(Interaction[ClientT]): def __init__(self, *, data: MessageInteractionPayload, state: ConnectionState) -> None: super().__init__(data=data, state=state) - self.data: MessageInteractionData = MessageInteractionData( - data=data["data"], state=state, guild_id=self.guild_id - ) + self.data: MessageInteractionData = MessageInteractionData(data=data["data"], parent=self) self.message = Message(state=self._state, channel=self.channel, data=data["message"]) @property @@ -167,8 +178,7 @@ def __init__( self, *, data: MessageComponentInteractionDataPayload, - state: ConnectionState, - guild_id: Optional[int], + parent: MessageInteraction[ClientT], ) -> None: super().__init__(data) self.custom_id: str = data["custom_id"] @@ -179,7 +189,7 @@ def __init__( empty_resolved: InteractionDataResolvedPayload = {} # pyright shenanigans self.resolved = InteractionDataResolved( - data=data.get("resolved", empty_resolved), state=state, guild_id=guild_id + data=data.get("resolved", empty_resolved), parent=parent ) def __repr__(self) -> str: diff --git a/disnake/interactions/modal.py b/disnake/interactions/modal.py index f631c38ac2..be9520b1cf 100644 --- a/disnake/interactions/modal.py +++ b/disnake/interactions/modal.py @@ -39,8 +39,21 @@ class ModalInteraction(Interaction[ClientT]): These are valid for 15 minutes. guild_id: Optional[:class:`int`] The guild ID the interaction was sent from. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` diff --git a/disnake/state.py b/disnake/state.py index ab4e5a8d78..84798d2fa5 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -43,6 +43,7 @@ TextChannel, VoiceChannel, _guild_channel_factory, + _threaded_guild_channel_factory, ) from .emoji import Emoji from .entitlement import Entitlement @@ -96,11 +97,13 @@ from .gateway import DiscordWebSocket from .guild import GuildChannel, VocalGuildChannel from .http import HTTPClient + from .interactions.base import InteractionChannel, InteractionMessageable from .types import gateway from .types.activity import Activity as ActivityPayload from .types.channel import DMChannel as DMChannelPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload from .types.guild import Guild as GuildPayload, UnavailableGuild as UnavailableGuildPayload + from .types.interactions import InteractionChannel as InteractionChannelPayload from .types.message import Message as MessagePayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.user import User as UserPayload @@ -2029,6 +2032,57 @@ def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmo except KeyError: return emoji + @overload + def _get_partial_interaction_channel( + self, + data: InteractionChannelPayload, + guild: Optional[Union[Guild, Object]], + *, + return_messageable: Literal[False] = False, + ) -> InteractionChannel: + ... + + @overload + def _get_partial_interaction_channel( + self, + data: InteractionChannelPayload, + guild: Optional[Union[Guild, Object]], + *, + return_messageable: Literal[True], + ) -> InteractionMessageable: + ... + + # note: this resolves private channels (and unknown types) to `PartialMessageable` + def _get_partial_interaction_channel( + self, + data: InteractionChannelPayload, + guild: Optional[Union[Guild, Object]], + *, + # this param is purely for type-checking, it has no effect on runtime behavior. + return_messageable: bool = False, + ) -> InteractionChannel: + channel_id = int(data["id"]) + channel_type = data["type"] + + factory, _ = _threaded_guild_channel_factory(channel_type) + if not factory or not guild: + return PartialMessageable( + state=self, + id=channel_id, + type=try_enum(ChannelType, channel_type), + ) + + data.setdefault("position", 0) # type: ignore + return ( + isinstance(guild, Guild) + and guild.get_channel_or_thread(channel_id) + or factory( + guild=guild, # type: ignore # FIXME: create proper fallback guild instead of passing Object + state=self, + data=data, # type: ignore # generic payload type + ) + ) + def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]: if id is None: return None diff --git a/disnake/types/interactions.py b/disnake/types/interactions.py index efffa8e599..88498da81f 100644 --- a/disnake/types/interactions.py +++ b/disnake/types/interactions.py @@ -89,7 +89,7 @@ class GuildApplicationCommandPermissions(TypedDict): InteractionType = Literal[1, 2, 3, 4, 5] -class ResolvedPartialChannel(TypedDict): +class InteractionChannel(TypedDict): id: Snowflake type: ChannelType permissions: str @@ -104,7 +104,7 @@ class InteractionDataResolved(TypedDict, total=False): users: Dict[Snowflake, User] members: Dict[Snowflake, Member] roles: Dict[Snowflake, Role] - channels: Dict[Snowflake, ResolvedPartialChannel] + channels: Dict[Snowflake, InteractionChannel] # only in application commands messages: Dict[Snowflake, Message] attachments: Dict[Snowflake, Attachment] @@ -258,9 +258,10 @@ class _BaseInteraction(TypedDict): # common properties in non-ping interactions class _BaseUserInteraction(_BaseInteraction): - # the docs specify `channel_id` as optional, - # but it is assumed to always exist on non-ping interactions + # the docs specify `channel_id` and 'channel` as optional, + # but they're assumed to always exist on non-ping interactions channel_id: Snowflake + channel: InteractionChannel locale: str app_permissions: NotRequired[str] guild_id: NotRequired[Snowflake] diff --git a/tests/interactions/test_base.py b/tests/interactions/test_base.py index 24d937b685..5e364072dc 100644 --- a/tests/interactions/test_base.py +++ b/tests/interactions/test_base.py @@ -8,12 +8,12 @@ import pytest import disnake -from disnake import InteractionResponseType as ResponseType # shortcut +from disnake import Interaction, InteractionResponseType as ResponseType # shortcut from disnake.state import ConnectionState from disnake.utils import MISSING if TYPE_CHECKING: - from disnake.types.interactions import ResolvedPartialChannel as ResolvedPartialChannelPayload + from disnake.types.interactions import InteractionChannel as InteractionChannelPayload from disnake.types.member import Member as MemberPayload from disnake.types.user import User as UserPayload @@ -137,7 +137,14 @@ def state(self): s._get_guild.return_value = None return s - def test_init_member(self, state) -> None: + @pytest.fixture + def interaction(self, state): + i = mock.Mock(spec_set=Interaction) + i._state = state + i.guild_id = 1234 + return i + + def test_init_member(self, interaction) -> None: member_payload: MemberPayload = { "roles": [], "joined_at": "2022-09-02T22:00:55.069000+00:00", @@ -156,8 +163,7 @@ def test_init_member(self, state) -> None: # user only, should deserialize user object resolved = disnake.InteractionDataResolved( data={"users": {"1234": user_payload}}, - state=state, - guild_id=1234, + parent=interaction, ) assert len(resolved.members) == 0 assert len(resolved.users) == 1 @@ -165,8 +171,7 @@ def test_init_member(self, state) -> None: # member only, shouldn't deserialize anything resolved = disnake.InteractionDataResolved( data={"members": {"1234": member_payload}}, - state=state, - guild_id=1234, + parent=interaction, ) assert len(resolved.members) == 0 assert len(resolved.users) == 0 @@ -174,15 +179,14 @@ def test_init_member(self, state) -> None: # user + member, should deserialize member object only resolved = disnake.InteractionDataResolved( data={"users": {"1234": user_payload}, "members": {"1234": member_payload}}, - state=state, - guild_id=1234, + parent=interaction, ) assert len(resolved.members) == 1 assert len(resolved.users) == 0 - @pytest.mark.parametrize("channel_type", [t.value for t in disnake.ChannelType]) + @pytest.mark.parametrize("channel_type", [t.value for t in disnake.ChannelType] + [99]) def test_channel(self, state, channel_type) -> None: - channel_data: ResolvedPartialChannelPayload = { + channel_data: InteractionChannelPayload = { "id": "42", "type": channel_type, "permissions": "7", @@ -197,12 +201,14 @@ def test_channel(self, state, channel_type) -> None: "locked": False, } - resolved = disnake.InteractionDataResolved( - data={"channels": {"42": channel_data}}, state=state, guild_id=1234 + # this should not raise + channel = ConnectionState._get_partial_interaction_channel( + state, + channel_data, + disnake.Object(1234), + return_messageable=False, ) - assert len(resolved.channels) == 1 - channel = next(iter(resolved.channels.values())) - # should be partial if and only if it's a dm/group - # TODO: currently includes directory channels (14), see `InteractionDataResolved.__init__` - assert isinstance(channel, disnake.PartialMessageable) == (channel_type in (1, 3, 14)) + # should be partial if and only if it's a dm/group or unknown + # TODO: currently includes directory channels (14), see `_get_partial_interaction_channel` + assert isinstance(channel, disnake.PartialMessageable) == (channel_type in (1, 3, 14, 99))