diff --git a/disnake/_event_data.py b/disnake/_event_data.py index be7946a632..b25c46f3de 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from .enums import Event @@ -12,17 +12,17 @@ def __init__( self, *, arg_types: List[str], - bot: bool = False, + self_type: Optional[str] = None, event_only: bool = False, ) -> None: self.arg_types: Tuple[str, ...] = tuple(arg_types) - """Type names of event arguments, e.g. `("Guild", "User")`""" + """Type names of event arguments, e.g. `("Guild", "User")`.""" - self.bot: bool = bot - """Whether the event is specific to ext.commands""" + self.self_type: Optional[str] = self_type + """The annotation for the `self` parameter, used for bot-only events.""" self.event_only: bool = event_only - """Whether the event can only be used through `@event` and not other listeners""" + """Whether the event can only be used through `@event`, and not with listeners.""" EVENT_DATA: Dict[Event, EventData] = { @@ -317,52 +317,53 @@ def __init__( Event.raw_typing: EventData( arg_types=["RawTypingEvent"], ), + # bot-only: Event.command: EventData( - arg_types=["commands.Context"], - bot=True, + arg_types=["commands.Context[AnyPrefixBotT]"], + self_type="AnyPrefixBotT", ), Event.command_completion: EventData( - arg_types=["commands.Context"], - bot=True, + arg_types=["commands.Context[AnyPrefixBotT]"], + self_type="AnyPrefixBotT", ), Event.command_error: EventData( - arg_types=["commands.Context", "commands.CommandError"], - bot=True, + arg_types=["commands.Context[AnyPrefixBotT]", "commands.CommandError"], + self_type="AnyPrefixBotT", ), Event.slash_command: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.slash_command_completion: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.slash_command_error: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], - bot=True, + self_type="AnyBotT", ), Event.user_command: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.user_command_completion: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.user_command_error: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], - bot=True, + self_type="AnyBotT", ), Event.message_command: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.message_command_completion: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.message_command_error: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], - bot=True, + self_type="AnyBotT", ), } diff --git a/disnake/client.py b/disnake/client.py index 63f1c65f69..416fb69da1 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -121,14 +121,15 @@ from .types.gateway import SessionStartLimit as SessionStartLimitPayload from .voice_client import VoiceProtocol + AnyPrefixBot = Union[commands.Bot, commands.AutoShardedBot] AnyBot = Union[ - commands.Bot, - commands.AutoShardedBot, + AnyPrefixBot, commands.InteractionBot, commands.AutoShardedInteractionBot, ] # we can't use `typing.Self` when the `self: AnyBot` parameter is annotated, # so go back to the old way of using a TypeVar for those overloads + AnyPrefixBotT = TypeVar("AnyPrefixBotT", bound=AnyPrefixBot) AnyBotT = TypeVar("AnyBotT", bound=AnyBot) @@ -2770,34 +2771,36 @@ def wait_for( @overload @_generated def wait_for( - self: AnyBot, + self: AnyPrefixBotT, event: Literal[Event.command, "command"], *, - check: Optional[Callable[[commands.Context], bool]] = None, + check: Optional[Callable[[commands.Context[AnyPrefixBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, commands.Context]: + ) -> Coroutine[Any, Any, commands.Context[AnyPrefixBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyPrefixBotT, event: Literal[Event.command_completion, "command_completion"], *, - check: Optional[Callable[[commands.Context], bool]] = None, + check: Optional[Callable[[commands.Context[AnyPrefixBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, commands.Context]: + ) -> Coroutine[Any, Any, commands.Context[AnyPrefixBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyPrefixBotT, event: Literal[Event.command_error, "command_error"], *, - check: Optional[Callable[[commands.Context, commands.CommandError], bool]] = None, + check: Optional[ + Callable[[commands.Context[AnyPrefixBotT], commands.CommandError], bool] + ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[commands.Context, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[commands.Context[AnyPrefixBotT], commands.CommandError]]: ... @overload diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index bdcbe3cae2..d40b54ce52 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -69,7 +69,6 @@ from ..app_commands import Choices from ..client import Client from ..embeds import Embed - from ..ext.commands import AutoShardedBot, Bot from ..file import File from ..guild import GuildChannel, GuildMessageable from ..mentions import AllowedMentions @@ -90,8 +89,6 @@ InteractionChannel = Union[GuildChannel, Thread, PartialMessageable] - AnyBot = Union[Bot, AutoShardedBot] - MISSING: Any = utils.MISSING diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 68a7b4908c..9bbdb2fcea 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -126,11 +126,10 @@ def generate_wait_for_overload( new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) # set `self` annotation as a workaround for overloads in subclasses - if event_data.bot: - self_type = "AnyBotT" if "AnyBotT" in new_annotation_str else "AnyBot" # fun. + if event_data.self_type: new_overload = new_overload.with_deep_changes( get_param(new_overload, "self"), - annotation=cst.Annotation(cst.Name(self_type)), + annotation=cst.Annotation(cst.Name(event_data.self_type)), ) return new_overload