diff --git a/disnake/_event_data.py b/disnake/_event_data.py index b62444e932..be7946a632 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -246,25 +246,25 @@ def __init__( arg_types=["StageInstance"], ), Event.application_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[Self]"], ), Event.application_command_autocomplete: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[Self]"], ), Event.button_click: EventData( - arg_types=["MessageInteraction"], + arg_types=["MessageInteraction[Self]"], ), Event.dropdown: EventData( - arg_types=["MessageInteraction"], + arg_types=["MessageInteraction[Self]"], ), Event.interaction: EventData( - arg_types=["Interaction"], + arg_types=["Interaction[Self]"], ), Event.message_interaction: EventData( - arg_types=["MessageInteraction"], + arg_types=["MessageInteraction[Self]"], ), Event.modal_submit: EventData( - arg_types=["ModalInteraction"], + arg_types=["ModalInteraction[Self]"], ), Event.message: EventData( arg_types=["Message"], @@ -330,39 +330,39 @@ def __init__( bot=True, ), Event.slash_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.slash_command_completion: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.slash_command_error: EventData( - arg_types=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], bot=True, ), Event.user_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.user_command_completion: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.user_command_error: EventData( - arg_types=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], bot=True, ), Event.message_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.message_command_completion: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.message_command_error: EventData( - arg_types=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], bot=True, ), } diff --git a/disnake/client.py b/disnake/client.py index ba8f49ebd7..63f1c65f69 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -79,6 +79,8 @@ from .widget import Widget if TYPE_CHECKING: + from typing_extensions import Self + from disnake.ext import commands from .abc import GuildChannel, Messageable, PrivateChannel, Snowflake, SnowflakeTime @@ -125,6 +127,9 @@ commands.InteractionBot, commands.AutoShardedInteractionBot, ] + # we can't use `typing.Self` when the `self: AnyBot` parameter is annotated, + # so go back to the old way of using a TypeVar for those overloads + AnyBotT = TypeVar("AnyBotT", bound=AnyBot) __all__ = ( @@ -2500,9 +2505,9 @@ def wait_for( self, event: Literal[Event.application_command, "application_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[Self]]: ... @overload @@ -2511,9 +2516,9 @@ def wait_for( self, event: Literal[Event.application_command_autocomplete, "application_command_autocomplete"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[Self]]: ... @overload @@ -2522,9 +2527,9 @@ def wait_for( self, event: Literal[Event.button_click, "button_click"], *, - check: Optional[Callable[[MessageInteraction], bool]] = None, + check: Optional[Callable[[MessageInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, MessageInteraction]: + ) -> Coroutine[Any, Any, MessageInteraction[Self]]: ... @overload @@ -2533,9 +2538,9 @@ def wait_for( self, event: Literal[Event.dropdown, "dropdown"], *, - check: Optional[Callable[[MessageInteraction], bool]] = None, + check: Optional[Callable[[MessageInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, MessageInteraction]: + ) -> Coroutine[Any, Any, MessageInteraction[Self]]: ... @overload @@ -2544,9 +2549,9 @@ def wait_for( self, event: Literal[Event.interaction, "interaction"], *, - check: Optional[Callable[[Interaction], bool]] = None, + check: Optional[Callable[[Interaction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Interaction]: + ) -> Coroutine[Any, Any, Interaction[Self]]: ... @overload @@ -2555,9 +2560,9 @@ def wait_for( self, event: Literal[Event.message_interaction, "message_interaction"], *, - check: Optional[Callable[[MessageInteraction], bool]] = None, + check: Optional[Callable[[MessageInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, MessageInteraction]: + ) -> Coroutine[Any, Any, MessageInteraction[Self]]: ... @overload @@ -2566,9 +2571,9 @@ def wait_for( self, event: Literal[Event.modal_submit, "modal_submit"], *, - check: Optional[Callable[[ModalInteraction], bool]] = None, + check: Optional[Callable[[ModalInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ModalInteraction]: + ) -> Coroutine[Any, Any, ModalInteraction[Self]]: ... @overload @@ -2798,106 +2803,106 @@ def wait_for( @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.slash_command, "slash_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.slash_command_completion, "slash_command_completion"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.slash_command_error, "slash_command_error"], *, check: Optional[ - Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool] ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.user_command, "user_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.user_command_completion, "user_command_completion"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.user_command_error, "user_command_error"], *, check: Optional[ - Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool] ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.message_command, "message_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.message_command_completion, "message_command_completion"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.message_command_error, "message_command_error"], *, check: Optional[ - Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool] ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]: ... # fallback for custom events diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index d617c99090..68a7b4908c 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -127,9 +127,10 @@ def generate_wait_for_overload( # set `self` annotation as a workaround for overloads in subclasses if event_data.bot: + self_type = "AnyBotT" if "AnyBotT" in new_annotation_str else "AnyBot" # fun. new_overload = new_overload.with_deep_changes( get_param(new_overload, "self"), - annotation=cst.Annotation(cst.Name("AnyBot")), + annotation=cst.Annotation(cst.Name(self_type)), ) return new_overload diff --git a/tests/test_events.py b/tests/test_events.py index a25f7dd771..fa4952195e 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -72,7 +72,7 @@ def _test_typing_wait_for(client: disnake.Client, bot: commands.Bot) -> None: _ = client.wait_for(Event.slash_command_error) # type: ignore # this should error _ = assert_type( bot.wait_for(Event.slash_command), - Coroutine[Any, Any, disnake.ApplicationCommandInteraction], + Coroutine[Any, Any, disnake.ApplicationCommandInteraction[commands.Bot]], )