Skip to content

Commit

Permalink
feat(typing): make Interaction and subclasses generic (#1037)
Browse files Browse the repository at this point in the history
Signed-off-by: Snipy7374 <[email protected]>
  • Loading branch information
Snipy7374 authored Sep 5, 2023
1 parent 0aaabf6 commit 80c8b32
Show file tree
Hide file tree
Showing 19 changed files with 85 additions and 61 deletions.
1 change: 1 addition & 0 deletions changelog/1036.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`Interaction` and subtypes accept the bot type as a generic parameter to denote the type returned by the :attr:`~Interaction.bot` and :attr:`~Interaction.client` properties.
21 changes: 14 additions & 7 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
T = TypeVar("T", bound=Any)
TypeT = TypeVar("TypeT", bound=Type[Any])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
BotT = TypeVar("BotT", bound="disnake.Client", covariant=True)

__all__ = (
"Range",
Expand Down Expand Up @@ -520,11 +521,11 @@ class ParamInfo:

def __init__(
self,
default: Union[Any, Callable[[ApplicationCommandInteraction], Any]] = ...,
default: Union[Any, Callable[[ApplicationCommandInteraction[BotT]], Any]] = ...,
*,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
converter: Optional[Callable[[ApplicationCommandInteraction, Any], Any]] = None,
converter: Optional[Callable[[ApplicationCommandInteraction[BotT], Any], Any]] = None,
convert_default: bool = False,
autocomplete: Optional[AnyAutocompleter] = None,
choices: Optional[Choices] = None,
Expand Down Expand Up @@ -911,6 +912,7 @@ def isolate_self(
parametersl.pop(0)
if parametersl:
annot = parametersl[0].annotation
annot = get_origin(annot) or annot
if issubclass_(annot, ApplicationCommandInteraction) or annot is inspect.Parameter.empty:
inter_param = parameters.pop(parametersl[0].name)

Expand Down Expand Up @@ -982,7 +984,9 @@ def collect_params(
injections[parameter.name] = default
elif parameter.annotation in Injection._registered:
injections[parameter.name] = Injection._registered[parameter.annotation]
elif issubclass_(parameter.annotation, ApplicationCommandInteraction):
elif issubclass_(
get_origin(parameter.annotation) or parameter.annotation, ApplicationCommandInteraction
):
if inter_param is None:
inter_param = parameter
else:
Expand Down Expand Up @@ -1116,21 +1120,24 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
if param.autocomplete:
command.autocompleters[param.name] = param.autocomplete

if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction):
if issubclass_(
get_origin(annot := sig.parameters[inter_param].annotation) or annot,
disnake.GuildCommandInteraction,
):
command._guild_only = True

return [param.to_option() for param in params]


def Param(
default: Union[Any, Callable[[ApplicationCommandInteraction], Any]] = ...,
default: Union[Any, Callable[[ApplicationCommandInteraction[BotT]], Any]] = ...,
*,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
choices: Optional[Choices] = None,
converter: Optional[Callable[[ApplicationCommandInteraction, Any], Any]] = None,
converter: Optional[Callable[[ApplicationCommandInteraction[BotT], Any], Any]] = None,
convert_defaults: bool = False,
autocomplete: Optional[Callable[[ApplicationCommandInteraction, str], Any]] = None,
autocomplete: Optional[Callable[[ApplicationCommandInteraction[BotT], str], Any]] = None,
channel_types: Optional[List[ChannelType]] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
Expand Down
10 changes: 5 additions & 5 deletions disnake/interactions/application_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..member import Member
from ..message import Message
from ..user import User
from .base import Interaction, InteractionDataResolved
from .base import ClientT, Interaction, InteractionDataResolved

__all__ = (
"ApplicationCommandInteraction",
Expand Down Expand Up @@ -41,7 +41,7 @@
)


class ApplicationCommandInteraction(Interaction):
class ApplicationCommandInteraction(Interaction[ClientT]):
"""Represents an interaction with an application command.
Current examples are slash commands, user commands and message commands.
Expand Down Expand Up @@ -119,7 +119,7 @@ def filled_options(self) -> Dict[str, Any]:
return kwargs


class GuildCommandInteraction(ApplicationCommandInteraction):
class GuildCommandInteraction(ApplicationCommandInteraction[ClientT]):
"""An :class:`ApplicationCommandInteraction` subclass, primarily meant for annotations.
This prevents the command from being invoked in DMs by automatically setting
Expand All @@ -137,7 +137,7 @@ class GuildCommandInteraction(ApplicationCommandInteraction):
me: Member


class UserCommandInteraction(ApplicationCommandInteraction):
class UserCommandInteraction(ApplicationCommandInteraction[ClientT]):
"""An :class:`ApplicationCommandInteraction` subclass meant for annotations.
No runtime behavior is changed but annotations are modified
Expand All @@ -147,7 +147,7 @@ class UserCommandInteraction(ApplicationCommandInteraction):
target: Union[User, Member]


class MessageCommandInteraction(ApplicationCommandInteraction):
class MessageCommandInteraction(ApplicationCommandInteraction[ClientT]):
"""An :class:`ApplicationCommandInteraction` subclass meant for annotations.
No runtime behavior is changed but annotations are modified
Expand Down
16 changes: 7 additions & 9 deletions disnake/interactions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Mapping,
Optional,
Expand Down Expand Up @@ -95,9 +96,10 @@
MISSING: Any = utils.MISSING

T = TypeVar("T")
ClientT = TypeVar("ClientT", bound="Client", covariant=True)


class Interaction:
class Interaction(Generic[ClientT]):
"""A base class representing a user-initiated Discord interaction.
An interaction happens when a user performs an action that the client needs to
Expand Down Expand Up @@ -175,7 +177,7 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState) -> None:
self._state: ConnectionState = state
# TODO: Maybe use a unique session
self._session: ClientSession = state.http._HTTPClient__session # type: ignore
self.client: Client = state._get_client()
self.client: ClientT = cast(ClientT, state._get_client())
self._original_response: Optional[InteractionMessage] = None

self.id: int = int(data["id"])
Expand Down Expand Up @@ -208,13 +210,9 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState) -> None:
self.author = self._state.store_user(user)

@property
def bot(self) -> AnyBot:
""":class:`~disnake.ext.commands.Bot`: The bot handling the interaction.
Only applicable when used with :class:`~disnake.ext.commands.Bot`.
This is an alias for :attr:`.client`.
"""
return self.client # type: ignore
def bot(self) -> ClientT:
""":class:`~disnake.ext.commands.Bot`: An alias for :attr:`.client`."""
return self.client

@property
def created_at(self) -> datetime:
Expand Down
4 changes: 2 additions & 2 deletions disnake/interactions/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..enums import ComponentType, try_enum
from ..message import Message
from ..utils import cached_slot_property
from .base import Interaction, InteractionDataResolved
from .base import ClientT, Interaction, InteractionDataResolved

__all__ = (
"MessageInteraction",
Expand All @@ -28,7 +28,7 @@
from .base import InteractionChannel


class MessageInteraction(Interaction):
class MessageInteraction(Interaction[ClientT]):
"""Represents an interaction with a message component.
Current examples are buttons and dropdowns.
Expand Down
4 changes: 2 additions & 2 deletions disnake/interactions/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..enums import ComponentType
from ..message import Message
from ..utils import cached_slot_property
from .base import Interaction
from .base import ClientT, Interaction

if TYPE_CHECKING:
from ..state import ConnectionState
Expand All @@ -21,7 +21,7 @@
__all__ = ("ModalInteraction", "ModalInteractionData")


class ModalInteraction(Interaction):
class ModalInteraction(Interaction[ClientT]):
"""Represents an interaction with a modal.
.. versionadded:: 2.4
Expand Down
5 changes: 4 additions & 1 deletion disnake/ui/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
if TYPE_CHECKING:
from typing_extensions import ParamSpec, Self

from ..client import Client
from ..components import NestedComponent
from ..enums import ComponentType
from ..interactions import MessageInteraction
Expand All @@ -35,6 +36,8 @@
else:
ParamSpec = TypeVar

ClientT = TypeVar("ClientT", bound="Client")


class WrappedComponent(ABC):
"""Represents the base UI component that all UI components inherit from.
Expand Down Expand Up @@ -142,7 +145,7 @@ def view(self) -> V_co:
"""Optional[:class:`View`]: The underlying view for this item."""
return self._view

async def callback(self, interaction: MessageInteraction, /) -> None:
async def callback(self, interaction: MessageInteraction[ClientT], /) -> None:
"""|coro|
The callback associated with this UI item.
Expand Down
9 changes: 6 additions & 3 deletions disnake/ui/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import os
import sys
import traceback
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar, Union

from ..enums import TextInputStyle
from ..utils import MISSING
from .action_row import ActionRow, components_to_rows
from .text_input import TextInput

if TYPE_CHECKING:
from ..client import Client
from ..interactions.modal import ModalInteraction
from ..state import ConnectionState
from ..types.components import Modal as ModalPayload
Expand All @@ -22,6 +23,8 @@

__all__ = ("Modal",)

ClientT = TypeVar("ClientT", bound="Client")


class Modal:
"""Represents a UI Modal.
Expand Down Expand Up @@ -156,7 +159,7 @@ def add_text_input(
)
)

async def callback(self, interaction: ModalInteraction, /) -> None:
async def callback(self, interaction: ModalInteraction[ClientT], /) -> None:
"""|coro|
The callback associated with this modal.
Expand All @@ -170,7 +173,7 @@ async def callback(self, interaction: ModalInteraction, /) -> None:
"""
pass

async def on_error(self, error: Exception, interaction: ModalInteraction) -> None:
async def on_error(self, error: Exception, interaction: ModalInteraction[ClientT]) -> None:
"""|coro|
A callback that is called when an error occurs.
Expand Down
6 changes: 3 additions & 3 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
self.__doc__ = function.__doc__

@overload
def __get__(self, instance: None, owner: Type[T]) -> Self:
def __get__(self, instance: None, owner: Type[Any]) -> Self:
...

@overload
def __get__(self, instance: T, owner: Type[T]) -> T_co:
def __get__(self, instance: T, owner: Type[Any]) -> T_co:
...

def __get__(self, instance: Optional[T], owner: Type[T]) -> Any:
def __get__(self, instance: Optional[T], owner: Type[Any]) -> Any:
if instance is None:
return self

Expand Down
2 changes: 1 addition & 1 deletion examples/interactions/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# which can be set using `Param` and the `converter` argument.
@bot.slash_command()
async def clean_command(
inter: disnake.CommandInteraction,
inter: disnake.CommandInteraction[commands.Bot],
text: str = commands.Param(converter=lambda inter, text: text.replace("@", "\\@")),
):
...
Expand Down
2 changes: 1 addition & 1 deletion examples/interactions/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def description(
# by using `Param` and passing a callable.
@bot.slash_command()
async def defaults(
inter: disnake.CommandInteraction,
inter: disnake.CommandInteraction[commands.Bot],
string: str = "this is a default value",
user: disnake.User = commands.Param(lambda inter: inter.author),
):
Expand Down
12 changes: 9 additions & 3 deletions test_bot/cogs/guild_scheduled_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ def __init__(self, bot: commands.Bot) -> None:

@commands.slash_command()
async def fetch_event(
self, inter: disnake.GuildCommandInteraction, id: commands.LargeInt
self, inter: disnake.GuildCommandInteraction[commands.Bot], id: commands.LargeInt
) -> None:
gse = await inter.guild.fetch_scheduled_event(id)
await inter.response.send_message(str(gse.image))

@commands.slash_command()
async def edit_event(
self, inter: disnake.GuildCommandInteraction, id: commands.LargeInt, new_image: bool
self,
inter: disnake.GuildCommandInteraction[commands.Bot],
id: commands.LargeInt,
new_image: bool,
) -> None:
await inter.response.defer()
gse = await inter.guild.fetch_scheduled_event(id)
Expand All @@ -33,7 +36,10 @@ async def edit_event(

@commands.slash_command()
async def create_event(
self, inter: disnake.GuildCommandInteraction, name: str, channel: disnake.VoiceChannel
self,
inter: disnake.GuildCommandInteraction[commands.Bot],
name: str,
channel: disnake.VoiceChannel,
) -> None:
image = disnake.File("./assets/banner.png")
gse = await inter.guild.create_scheduled_event(
Expand Down
8 changes: 4 additions & 4 deletions test_bot/cogs/injections.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, prefix: str, suffix: str = "") -> None:
self.prefix = prefix
self.suffix = suffix

def __call__(self, inter: disnake.CommandInteraction, a: str = "init"):
def __call__(self, inter: disnake.CommandInteraction[commands.Bot], a: str = "init"):
return self.prefix + a + self.suffix


Expand All @@ -41,7 +41,7 @@ def __init__(self, username: str, discriminator: str) -> None:
self.discriminator = discriminator

@commands.converter_method
async def convert(cls, inter: disnake.CommandInteraction, user: disnake.User):
async def convert(cls, inter: disnake.CommandInteraction[commands.Bot], user: disnake.User):
return cls(user.name, user.discriminator)

def __repr__(self) -> str:
Expand Down Expand Up @@ -89,7 +89,7 @@ async def injected_method(self, number: int = 3):
@commands.slash_command()
async def injection_command(
self,
inter: disnake.CommandInteraction,
inter: disnake.CommandInteraction[commands.Bot],
sqrt: Optional[float] = commands.Param(None, converter=lambda i, x: x**0.5),
prefixed: str = commands.Param(converter=PrefixConverter("__", "__")),
other: Tuple[int, str] = commands.inject(injected),
Expand All @@ -109,7 +109,7 @@ async def injection_command(
@commands.slash_command()
async def discerned_injections(
self,
inter: disnake.CommandInteraction,
inter: disnake.CommandInteraction[commands.Bot],
perhaps: PerhapsThis,
god: Optional[HopeToGod] = None,
) -> None:
Expand Down
Loading

0 comments on commit 80c8b32

Please sign in to comment.