Skip to content

Commit

Permalink
fix(typing): add generic bot parameter to interaction types
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Sep 20, 2023
1 parent fd56f51 commit 600af60
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 59 deletions.
32 changes: 16 additions & 16 deletions disnake/_event_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,25 +246,25 @@ def __init__(
arg_types=["StageInstance"],
),
Event.application_command: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[Self]"],
),
Event.application_command_autocomplete: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[Self]"],
),
Event.button_click: EventData(
arg_types=["MessageInteraction"],
arg_types=["MessageInteraction[Self]"],
),
Event.dropdown: EventData(
arg_types=["MessageInteraction"],
arg_types=["MessageInteraction[Self]"],
),
Event.interaction: EventData(
arg_types=["Interaction"],
arg_types=["Interaction[Self]"],
),
Event.message_interaction: EventData(
arg_types=["MessageInteraction"],
arg_types=["MessageInteraction[Self]"],
),
Event.modal_submit: EventData(
arg_types=["ModalInteraction"],
arg_types=["ModalInteraction[Self]"],
),
Event.message: EventData(
arg_types=["Message"],
Expand Down Expand Up @@ -330,39 +330,39 @@ def __init__(
bot=True,
),
Event.slash_command: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
),
Event.slash_command_completion: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
),
Event.slash_command_error: EventData(
arg_types=["ApplicationCommandInteraction", "commands.CommandError"],
arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"],
bot=True,
),
Event.user_command: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
),
Event.user_command_completion: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
),
Event.user_command_error: EventData(
arg_types=["ApplicationCommandInteraction", "commands.CommandError"],
arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"],
bot=True,
),
Event.message_command: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
),
Event.message_command_completion: EventData(
arg_types=["ApplicationCommandInteraction"],
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
),
Event.message_command_error: EventData(
arg_types=["ApplicationCommandInteraction", "commands.CommandError"],
arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"],
bot=True,
),
}
87 changes: 46 additions & 41 deletions disnake/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
from .widget import Widget

if TYPE_CHECKING:
from typing_extensions import Self

from disnake.ext import commands

from .abc import GuildChannel, Messageable, PrivateChannel, Snowflake, SnowflakeTime
Expand Down Expand Up @@ -125,6 +127,9 @@
commands.InteractionBot,
commands.AutoShardedInteractionBot,
]
# we can't use `typing.Self` when the `self: AnyBot` parameter is annotated,
# so go back to the old way of using a TypeVar for those overloads
AnyBotT = TypeVar("AnyBotT", bound=AnyBot)


__all__ = (
Expand Down Expand Up @@ -2500,9 +2505,9 @@ def wait_for(
self,
event: Literal[Event.application_command, "application_command"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[Self]]:
...

@overload
Expand All @@ -2511,9 +2516,9 @@ def wait_for(
self,
event: Literal[Event.application_command_autocomplete, "application_command_autocomplete"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[Self]]:
...

@overload
Expand All @@ -2522,9 +2527,9 @@ def wait_for(
self,
event: Literal[Event.button_click, "button_click"],
*,
check: Optional[Callable[[MessageInteraction], bool]] = None,
check: Optional[Callable[[MessageInteraction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, MessageInteraction]:
) -> Coroutine[Any, Any, MessageInteraction[Self]]:
...

@overload
Expand All @@ -2533,9 +2538,9 @@ def wait_for(
self,
event: Literal[Event.dropdown, "dropdown"],
*,
check: Optional[Callable[[MessageInteraction], bool]] = None,
check: Optional[Callable[[MessageInteraction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, MessageInteraction]:
) -> Coroutine[Any, Any, MessageInteraction[Self]]:
...

@overload
Expand All @@ -2544,9 +2549,9 @@ def wait_for(
self,
event: Literal[Event.interaction, "interaction"],
*,
check: Optional[Callable[[Interaction], bool]] = None,
check: Optional[Callable[[Interaction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, Interaction]:
) -> Coroutine[Any, Any, Interaction[Self]]:
...

@overload
Expand All @@ -2555,9 +2560,9 @@ def wait_for(
self,
event: Literal[Event.message_interaction, "message_interaction"],
*,
check: Optional[Callable[[MessageInteraction], bool]] = None,
check: Optional[Callable[[MessageInteraction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, MessageInteraction]:
) -> Coroutine[Any, Any, MessageInteraction[Self]]:
...

@overload
Expand All @@ -2566,9 +2571,9 @@ def wait_for(
self,
event: Literal[Event.modal_submit, "modal_submit"],
*,
check: Optional[Callable[[ModalInteraction], bool]] = None,
check: Optional[Callable[[ModalInteraction[Self]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ModalInteraction]:
) -> Coroutine[Any, Any, ModalInteraction[Self]]:
...

@overload
Expand Down Expand Up @@ -2798,106 +2803,106 @@ def wait_for(
@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.slash_command, "slash_command"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.slash_command_completion, "slash_command_completion"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.slash_command_error, "slash_command_error"],
*,
check: Optional[
Callable[[ApplicationCommandInteraction, commands.CommandError], bool]
Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool]
] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]:
) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.user_command, "user_command"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.user_command_completion, "user_command_completion"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.user_command_error, "user_command_error"],
*,
check: Optional[
Callable[[ApplicationCommandInteraction, commands.CommandError], bool]
Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool]
] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]:
) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.message_command, "message_command"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.message_command_completion, "message_command_completion"],
*,
check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None,
check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, ApplicationCommandInteraction]:
) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyBotT,
event: Literal[Event.message_command_error, "message_command_error"],
*,
check: Optional[
Callable[[ApplicationCommandInteraction, commands.CommandError], bool]
Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool]
] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]:
) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]:
...

# fallback for custom events
Expand Down
3 changes: 2 additions & 1 deletion scripts/codemods/typed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def generate_wait_for_overload(

# set `self` annotation as a workaround for overloads in subclasses
if event_data.bot:
self_type = "AnyBotT" if "AnyBotT" in new_annotation_str else "AnyBot" # fun.
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "self"),
annotation=cst.Annotation(cst.Name("AnyBot")),
annotation=cst.Annotation(cst.Name(self_type)),
)

return new_overload
2 changes: 1 addition & 1 deletion tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _test_typing_wait_for(client: disnake.Client, bot: commands.Bot) -> None:
_ = client.wait_for(Event.slash_command_error) # type: ignore # this should error
_ = assert_type(
bot.wait_for(Event.slash_command),
Coroutine[Any, Any, disnake.ApplicationCommandInteraction],
Coroutine[Any, Any, disnake.ApplicationCommandInteraction[commands.Bot]],
)


Expand Down

0 comments on commit 600af60

Please sign in to comment.