From fe583fc85c2baa75c6e66d28435bcfef670bd5c9 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Wed, 26 Apr 2023 20:31:58 +0200 Subject: [PATCH 01/22] feat(codemod): autogen `Client.wait_for` overloads --- disnake/client.py | 1229 +++++++++++++++++++++++++++++- disnake/utils.py | 14 +- pyproject.toml | 1 + scripts/codemods/typed_events.py | 251 ++++++ 4 files changed, 1483 insertions(+), 12 deletions(-) create mode 100644 scripts/codemods/typed_events.py diff --git a/disnake/client.py b/disnake/client.py index 56d60280f8..c84012a031 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -67,23 +67,51 @@ from .state import ConnectionState from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .template import Template -from .threads import Thread +from .threads import Thread, ThreadMember from .ui.view import View from .user import ClientUser, User -from .utils import MISSING +from .utils import MISSING, _generated, _overload_with_events from .voice_client import VoiceClient from .voice_region import VoiceRegion from .webhook import Webhook from .widget import Widget if TYPE_CHECKING: - from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime + from disnake.ext import commands + + from .abc import GuildChannel, Messageable, PrivateChannel, Snowflake, SnowflakeTime from .app_commands import APIApplicationCommand from .asset import AssetBytes - from .channel import DMChannel + from .audit_logs import AuditLogEntry + from .automod import AutoModActionExecution, AutoModRule + from .channel import DMChannel, ForumChannel, GroupChannel from .enums import Event - from .member import Member + from .guild_scheduled_event import GuildScheduledEvent + from .integrations import Integration + from .interactions import ( + ApplicationCommandInteraction, + Interaction, + MessageInteraction, + ModalInteraction, + ) + from .member import Member, VoiceState from .message import Message + from .raw_models import ( + RawBulkMessageDeleteEvent, + RawGuildMemberRemoveEvent, + RawGuildScheduledEventUserActionEvent, + RawIntegrationDeleteEvent, + RawMessageDeleteEvent, + RawMessageUpdateEvent, + RawReactionActionEvent, + RawReactionClearEmojiEvent, + RawReactionClearEvent, + RawThreadDeleteEvent, + RawThreadMemberRemoveEvent, + RawTypingEvent, + ) + from .reaction import Reaction + from .role import Role from .types.application_role_connection import ( ApplicationRoleConnectionMetadata as ApplicationRoleConnectionMetadataPayload, ) @@ -1516,13 +1544,1202 @@ async def wait_until_first_connect(self) -> None: """ await self._first_connect.wait() + # all these overloads are autogenerated, see ./scripts + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.connect, "connect"], + *, + check: Optional[Callable[[], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, None]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.disconnect, "disconnect"], + *, + check: Optional[Callable[[], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, None]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.ready, "ready"], + *, + check: Optional[Callable[[], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, None]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.resumed, "resumed"], + *, + check: Optional[Callable[[], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, None]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.shard_connect, "shard_connect"], + *, + check: Optional[Callable[[int], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, int]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.shard_disconnect, "shard_disconnect"], + *, + check: Optional[Callable[[int], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, int]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.shard_ready, "shard_ready"], + *, + check: Optional[Callable[[int], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, int]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.shard_resumed, "shard_resumed"], + *, + check: Optional[Callable[[int], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, int]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.socket_event_type, "socket_event_type"], + *, + check: Optional[Callable[[str], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, str]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.socket_raw_receive, "socket_raw_receive"], + *, + check: Optional[Callable[[str], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, str]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.socket_raw_send, "socket_raw_send"], + *, + check: Optional[Callable[[Union[str, bytes]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Union[str, bytes]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_channel_create, "guild_channel_create"], + *, + check: Optional[Callable[[GuildChannel], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, GuildChannel]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_channel_update, "guild_channel_update"], + *, + check: Optional[Callable[[GuildChannel, GuildChannel], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[GuildChannel, GuildChannel]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_channel_delete, "guild_channel_delete"], + *, + check: Optional[Callable[[GuildChannel], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, GuildChannel]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_channel_pins_update, "guild_channel_pins_update"], + *, + check: Optional[Callable[[Union[GuildChannel, Thread], Optional[datetime]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Union[GuildChannel, Thread], Optional[datetime]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.invite_create, "invite_create"], + *, + check: Optional[Callable[[Invite], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Invite]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.invite_delete, "invite_delete"], + *, + check: Optional[Callable[[Invite], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Invite]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.private_channel_update, "private_channel_update"], + *, + check: Optional[Callable[[GroupChannel, GroupChannel], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[GroupChannel, GroupChannel]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.private_channel_pins_update, "private_channel_pins_update"], + *, + check: Optional[Callable[[PrivateChannel, Optional[datetime]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[PrivateChannel, Optional[datetime]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.webhooks_update, "webhooks_update"], + *, + check: Optional[Callable[[GuildChannel], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, GuildChannel]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_create, "thread_create"], + *, + check: Optional[Callable[[Thread], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Thread]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_update, "thread_update"], + *, + check: Optional[Callable[[Thread, Thread], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Thread, Thread]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_delete, "thread_delete"], + *, + check: Optional[Callable[[Thread], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Thread]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_join, "thread_join"], + *, + check: Optional[Callable[[Thread], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Thread]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_remove, "thread_remove"], + *, + check: Optional[Callable[[Thread], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Thread]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_member_join, "thread_member_join"], + *, + check: Optional[Callable[[ThreadMember], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ThreadMember]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.thread_member_remove, "thread_member_remove"], + *, + check: Optional[Callable[[ThreadMember], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ThreadMember]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_thread_member_remove, "raw_thread_member_remove"], + *, + check: Optional[Callable[[RawThreadMemberRemoveEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawThreadMemberRemoveEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_thread_update, "raw_thread_update"], + *, + check: Optional[Callable[[Thread], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Thread]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_thread_delete, "raw_thread_delete"], + *, + check: Optional[Callable[[RawThreadDeleteEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawThreadDeleteEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_join, "guild_join"], + *, + check: Optional[Callable[[Guild], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Guild]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_remove, "guild_remove"], + *, + check: Optional[Callable[[Guild], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Guild]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_update, "guild_update"], + *, + check: Optional[Callable[[Guild, Guild], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Guild, Guild]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_available, "guild_available"], + *, + check: Optional[Callable[[Guild], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Guild]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_unavailable, "guild_unavailable"], + *, + check: Optional[Callable[[Guild], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Guild]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_role_create, "guild_role_create"], + *, + check: Optional[Callable[[Role], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Role]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_role_delete, "guild_role_delete"], + *, + check: Optional[Callable[[Role], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Role]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_role_update, "guild_role_update"], + *, + check: Optional[Callable[[Role, Role], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Role, Role]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_emojis_update, "guild_emojis_update"], + *, + check: Optional[Callable[[Guild, Sequence[Emoji], Sequence[Emoji]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Guild, Sequence[Emoji], Sequence[Emoji]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_stickers_update, "guild_stickers_update"], + *, + check: Optional[ + Callable[[Guild, Sequence[GuildSticker], Sequence[GuildSticker]], bool] + ] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Guild, Sequence[GuildSticker], Sequence[GuildSticker]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_integrations_update, "guild_integrations_update"], + *, + check: Optional[Callable[[Guild], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Guild]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_scheduled_event_create, "guild_scheduled_event_create"], + *, + check: Optional[Callable[[GuildScheduledEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, GuildScheduledEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_scheduled_event_update, "guild_scheduled_event_update"], + *, + check: Optional[Callable[[GuildScheduledEvent, GuildScheduledEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[GuildScheduledEvent, GuildScheduledEvent]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_scheduled_event_delete, "guild_scheduled_event_delete"], + *, + check: Optional[Callable[[GuildScheduledEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, GuildScheduledEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.guild_scheduled_event_subscribe, "guild_scheduled_event_subscribe"], + *, + check: Optional[Callable[[GuildScheduledEvent, Union[Member, User]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[GuildScheduledEvent, Union[Member, User]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[ + Event.guild_scheduled_event_unsubscribe, "guild_scheduled_event_unsubscribe" + ], + *, + check: Optional[Callable[[GuildScheduledEvent, Union[Member, User]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[GuildScheduledEvent, Union[Member, User]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[ + Event.raw_guild_scheduled_event_subscribe, "raw_guild_scheduled_event_subscribe" + ], + *, + check: Optional[Callable[[RawGuildScheduledEventUserActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawGuildScheduledEventUserActionEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[ + Event.raw_guild_scheduled_event_unsubscribe, "raw_guild_scheduled_event_unsubscribe" + ], + *, + check: Optional[Callable[[RawGuildScheduledEventUserActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawGuildScheduledEventUserActionEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[ + Event.application_command_permissions_update, "application_command_permissions_update" + ], + *, + check: Optional[Callable[[GuildApplicationCommandPermissions], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, GuildApplicationCommandPermissions]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.automod_action_execution, "automod_action_execution"], + *, + check: Optional[Callable[[AutoModActionExecution], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, AutoModActionExecution]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.automod_rule_create, "automod_rule_create"], + *, + check: Optional[Callable[[AutoModRule], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, AutoModRule]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.automod_rule_update, "automod_rule_update"], + *, + check: Optional[Callable[[AutoModRule], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, AutoModRule]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.automod_rule_delete, "automod_rule_delete"], + *, + check: Optional[Callable[[AutoModRule], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, AutoModRule]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.audit_log_entry_create, "audit_log_entry_create"], + *, + check: Optional[Callable[[AuditLogEntry], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, AuditLogEntry]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.integration_create, "integration_create"], + *, + check: Optional[Callable[[Integration], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Integration]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.integration_update, "integration_update"], + *, + check: Optional[Callable[[Integration], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Integration]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_integration_delete, "raw_integration_delete"], + *, + check: Optional[Callable[[RawIntegrationDeleteEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawIntegrationDeleteEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.member_join, "member_join"], + *, + check: Optional[Callable[[Member], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Member]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.member_remove, "member_remove"], + *, + check: Optional[Callable[[Member], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Member]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.member_update, "member_update"], + *, + check: Optional[Callable[[Member, Member], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Member, Member]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_member_remove, "raw_member_remove"], + *, + check: Optional[Callable[[RawGuildMemberRemoveEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawGuildMemberRemoveEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_member_update, "raw_member_update"], + *, + check: Optional[Callable[[Member], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Member]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.member_ban, "member_ban"], + *, + check: Optional[Callable[[Guild, Union[User, Member]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Guild, Union[User, Member]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.member_unban, "member_unban"], + *, + check: Optional[Callable[[Guild, User], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Guild, User]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.presence_update, "presence_update"], + *, + check: Optional[Callable[[Member, Member], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Member, Member]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.user_update, "user_update"], + *, + check: Optional[Callable[[User, User], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[User, User]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.voice_state_update, "voice_state_update"], + *, + check: Optional[Callable[[Member, VoiceState, VoiceState], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Member, VoiceState, VoiceState]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.stage_instance_create, "stage_instance_create"], + *, + check: Optional[Callable[[StageInstance], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, StageInstance]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.stage_instance_delete, "stage_instance_delete"], + *, + check: Optional[Callable[[StageInstance, StageInstance], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[StageInstance, StageInstance]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.stage_instance_update, "stage_instance_update"], + *, + check: Optional[Callable[[StageInstance], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, StageInstance]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.application_command, "application_command"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.application_command_autocomplete, "application_command_autocomplete"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.button_click, "button_click"], + *, + check: Optional[Callable[[MessageInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, MessageInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.dropdown, "dropdown"], + *, + check: Optional[Callable[[MessageInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, MessageInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.interaction, "interaction"], + *, + check: Optional[Callable[[Interaction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Interaction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message_interaction, "message_interaction"], + *, + check: Optional[Callable[[MessageInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, MessageInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.modal_submit, "modal_submit"], + *, + check: Optional[Callable[[ModalInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ModalInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message, "message"], + *, + check: Optional[Callable[[Message], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Message]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message_edit, "message_edit"], + *, + check: Optional[Callable[[Message, Message], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Message, Message]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message_delete, "message_delete"], + *, + check: Optional[Callable[[Message], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Message]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.bulk_message_delete, "bulk_message_delete"], + *, + check: Optional[Callable[[List[Message]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, List[Message]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_message_edit, "raw_message_edit"], + *, + check: Optional[Callable[[RawMessageUpdateEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawMessageUpdateEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_message_delete, "raw_message_delete"], + *, + check: Optional[Callable[[RawMessageDeleteEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawMessageDeleteEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_bulk_message_delete, "raw_bulk_message_delete"], + *, + check: Optional[Callable[[RawBulkMessageDeleteEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawBulkMessageDeleteEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.reaction_add, "reaction_add"], + *, + check: Optional[Callable[[Reaction, Union[Member, User]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Reaction, Union[Member, User]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.reaction_remove, "reaction_remove"], + *, + check: Optional[Callable[[Reaction, Union[Member, User]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Reaction, Union[Member, User]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.reaction_clear, "reaction_clear"], + *, + check: Optional[Callable[[Message, List[Reaction]], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Message, List[Reaction]]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.reaction_clear_emoji, "reaction_clear_emoji"], + *, + check: Optional[Callable[[Reaction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Reaction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_reaction_add, "raw_reaction_add"], + *, + check: Optional[Callable[[RawReactionActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawReactionActionEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_reaction_remove, "raw_reaction_remove"], + *, + check: Optional[Callable[[RawReactionActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawReactionActionEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_reaction_clear, "raw_reaction_clear"], + *, + check: Optional[Callable[[RawReactionClearEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawReactionClearEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_reaction_clear_emoji, "raw_reaction_clear_emoji"], + *, + check: Optional[Callable[[RawReactionClearEmojiEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawReactionClearEmojiEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.typing, "typing"], + *, + check: Optional[ + Callable[[Union[Messageable, ForumChannel], Union[User, Member], datetime], bool] + ] = None, + timeout: Optional[float] = None, + ) -> Coroutine[ + Any, Any, Tuple[Union[Messageable, ForumChannel], Union[User, Member], datetime] + ]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_typing, "raw_typing"], + *, + check: Optional[Callable[[RawTypingEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawTypingEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.command, "command"], + *, + check: Optional[Callable[[commands.Context], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, commands.Context]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.command_completion, "command_completion"], + *, + check: Optional[Callable[[commands.Context], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, commands.Context]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.command_error, "command_error"], + *, + check: Optional[Callable[[commands.Context, commands.CommandError], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[commands.Context, commands.CommandError]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.slash_command, "slash_command"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.slash_command_completion, "slash_command_completion"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.slash_command_error, "slash_command_error"], + *, + check: Optional[ + Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + ] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.user_command, "user_command"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.user_command_completion, "user_command_completion"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.user_command_error, "user_command_error"], + *, + check: Optional[ + Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + ] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message_command, "message_command"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message_command_completion, "message_command_completion"], + *, + check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.message_command_error, "message_command_error"], + *, + check: Optional[ + Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + ] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ... + + @_overload_with_events def wait_for( self, event: Union[str, Event], *, check: Optional[Callable[..., bool]] = None, timeout: Optional[float] = None, - ) -> Any: + ) -> Coroutine[Any, Any, Any]: """|coro| Waits for a WebSocket event to be dispatched. diff --git a/disnake/utils.py b/disnake/utils.py index 54781da834..352550646d 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1331,13 +1331,15 @@ def assert_never(arg: Never, /) -> None: pass -# n.b. This must be imported and used as @ _overload_with_permissions (without the space) -# this is used by the libcst parser and has no runtime purpose -# it is merely a marker not unlike pytest.mark -def _overload_with_permissions(func: T) -> T: +# n.b. These must be imported and used as @. +# This is used by the libcst parser and has no runtime purpose; +# it is merely a marker not unlike pytest.mark. +def _noop_decorator(func: T) -> T: return func # this is used as a marker for functions or classes that were created by codemodding -def _generated(func: T) -> T: - return func +_generated = _noop_decorator + +_overload_with_permissions = _noop_decorator +_overload_with_events = _noop_decorator diff --git a/pyproject.toml b/pyproject.toml index 99bdec9556..995a86a0bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,6 +216,7 @@ ignore = [ "PT", # this is not a module of pytest tests ] "tests/*.py" = ["S101"] # use of assert is okay in test files +"scripts/*.py" = ["S101"] # use of assert is okay in codemods # we are not using noqa in the example files themselves "examples/*.py" = [ "B008", # do not perform function calls in argument defaults, this is how most commands work diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py new file mode 100644 index 0000000000..3189342f66 --- /dev/null +++ b/scripts/codemods/typed_events.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: MIT + +import types +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, cast + +import libcst as cst +import libcst.matchers as m +from libcst import codemod + +from disnake import Event + + +def get_param(func: cst.FunctionDef, name: str) -> cst.Param: + results = m.findall(func.params, m.Param(m.Name(name))) + assert len(results) == 1 + return cast(cst.Param, results[0]) + + +class EventTypings(codemod.VisitorBasedCodemodCommand): + DESCRIPTION: str = "Adds overloads for library events." + + flag_classes: List[str] + imported_module: types.ModuleType + + def transform_module(self, tree: cst.Module) -> cst.Module: + if "@_overload_with_events" not in tree.code: + raise codemod.SkipFile( + "this module does not contain the required decorator: `@_overload_with_events`." + ) + return super().transform_module(tree) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: + # don't recurse into the body of a function + return False + + def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): + decorators = [ + deco.decorator.value + for deco in node.decorators + if not isinstance(deco.decorator, cst.Call) + ] + if "_generated" in decorators: + # remove generated methods + return cst.RemovalSentinel.REMOVE + if "_overload_with_events" not in decorators: + # ignore + return node + + if node.name.value != "wait_for": + raise RuntimeError( + f"unknown method '{node.name.value}' with @_overload_with_events decorator" + ) + + # if we're here, we found a @_overload_with_events decorator + new_overloads: List[cst.FunctionDef] = [] + for event in EVENT_DATA.keys(): + event_data = EVENT_DATA[event] + if event_data.event_only: + continue + args = event_data.args + + new_overload = node.with_changes( + body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), + decorators=[ + cst.Decorator(cst.Name("overload")), + cst.Decorator(cst.Name("_generated")), + ], + ) + + # set `event` annotation + new_annotation = cst.parse_expression( + # the lazy way of doing things + f'Literal[Event.{event.name}, "{event.value}"]', + config=self.module.config_for_parsing, + ) + new_overload = new_overload.with_deep_changes( + get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) + ) + + # set `check` annotation + callable_annotation = m.findall( + get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) + )[0] + callable_params = m.findall(callable_annotation, m.Ellipsis())[0] + new_annotation = cst.parse_expression( + f'[{",".join(args)}]', + config=self.module.config_for_parsing, + ) + new_overload = new_overload.deep_replace(callable_params, new_annotation) + + # set return annotation + if len(args) == 0: + new_annotation_str = "None" + elif len(args) == 1: + new_annotation_str = args[0] + else: + new_annotation_str = f'Tuple[{",".join(args)}]' + new_annotation = cst.parse_expression( + f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing + ) + new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) + + new_overloads.append(new_overload) + + return cst.FlattenSentinel([*new_overloads, node]) + + +@dataclass +class EventData: + # type names of event arguments, e.g. `("Guild", "User")` + args: Tuple[str, ...] + # whether the event is specific to ext.commands + bot: bool = False + # whether the event can only be used through `@event` and not other listeners + event_only: bool = False + + +EVENT_DATA: Dict[Event, EventData] = { + Event.connect: EventData(()), + Event.disconnect: EventData(()), + # TODO: figure out how to specify varargs for these two + Event.error: EventData((), event_only=True), + Event.gateway_error: EventData((), event_only=True), + Event.ready: EventData(()), + Event.resumed: EventData(()), + Event.shard_connect: EventData(("int",)), + Event.shard_disconnect: EventData(("int",)), + Event.shard_ready: EventData(("int",)), + Event.shard_resumed: EventData(("int",)), + Event.socket_event_type: EventData(("str",)), + Event.socket_raw_receive: EventData(("str",)), + Event.socket_raw_send: EventData(("Union[str, bytes]",)), + Event.guild_channel_create: EventData(("GuildChannel",)), + Event.guild_channel_update: EventData(("GuildChannel", "GuildChannel")), + Event.guild_channel_delete: EventData(("GuildChannel",)), + Event.guild_channel_pins_update: EventData( + ("Union[GuildChannel, Thread]", "Optional[datetime]") + ), + Event.invite_create: EventData(("Invite",)), + Event.invite_delete: EventData(("Invite",)), + Event.private_channel_update: EventData(("GroupChannel", "GroupChannel")), + Event.private_channel_pins_update: EventData(("PrivateChannel", "Optional[datetime]")), + Event.webhooks_update: EventData(("GuildChannel",)), + Event.thread_create: EventData(("Thread",)), + Event.thread_update: EventData(("Thread", "Thread")), + Event.thread_delete: EventData(("Thread",)), + Event.thread_join: EventData(("Thread",)), + Event.thread_remove: EventData(("Thread",)), + Event.thread_member_join: EventData(("ThreadMember",)), + Event.thread_member_remove: EventData(("ThreadMember",)), + Event.raw_thread_member_remove: EventData(("RawThreadMemberRemoveEvent",)), + Event.raw_thread_update: EventData(("Thread",)), + Event.raw_thread_delete: EventData(("RawThreadDeleteEvent",)), + Event.guild_join: EventData(("Guild",)), + Event.guild_remove: EventData(("Guild",)), + Event.guild_update: EventData(("Guild", "Guild")), + Event.guild_available: EventData(("Guild",)), + Event.guild_unavailable: EventData(("Guild",)), + Event.guild_role_create: EventData(("Role",)), + Event.guild_role_delete: EventData(("Role",)), + Event.guild_role_update: EventData(("Role", "Role")), + Event.guild_emojis_update: EventData(("Guild", "Sequence[Emoji]", "Sequence[Emoji]")), + Event.guild_stickers_update: EventData( + ("Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]") + ), + Event.guild_integrations_update: EventData(("Guild",)), + Event.guild_scheduled_event_create: EventData(("GuildScheduledEvent",)), + Event.guild_scheduled_event_update: EventData(("GuildScheduledEvent", "GuildScheduledEvent")), + Event.guild_scheduled_event_delete: EventData(("GuildScheduledEvent",)), + Event.guild_scheduled_event_subscribe: EventData( + ("GuildScheduledEvent", "Union[Member, User]") + ), + Event.guild_scheduled_event_unsubscribe: EventData( + ("GuildScheduledEvent", "Union[Member, User]") + ), + Event.raw_guild_scheduled_event_subscribe: EventData( + ("RawGuildScheduledEventUserActionEvent",) + ), + Event.raw_guild_scheduled_event_unsubscribe: EventData( + ("RawGuildScheduledEventUserActionEvent",) + ), + Event.application_command_permissions_update: EventData( + ("GuildApplicationCommandPermissions",) + ), + Event.automod_action_execution: EventData(("AutoModActionExecution",)), + Event.automod_rule_create: EventData(("AutoModRule",)), + Event.automod_rule_update: EventData(("AutoModRule",)), + Event.automod_rule_delete: EventData(("AutoModRule",)), + Event.audit_log_entry_create: EventData(("AuditLogEntry",)), + Event.integration_create: EventData(("Integration",)), + Event.integration_update: EventData(("Integration",)), + Event.raw_integration_delete: EventData(("RawIntegrationDeleteEvent",)), + Event.member_join: EventData(("Member",)), + Event.member_remove: EventData(("Member",)), + Event.member_update: EventData(("Member", "Member")), + Event.raw_member_remove: EventData(("RawGuildMemberRemoveEvent",)), + Event.raw_member_update: EventData(("Member",)), + Event.member_ban: EventData(("Guild", "Union[User, Member]")), + Event.member_unban: EventData(("Guild", "User")), + Event.presence_update: EventData(("Member", "Member")), + Event.user_update: EventData(("User", "User")), + Event.voice_state_update: EventData(("Member", "VoiceState", "VoiceState")), + Event.stage_instance_create: EventData(("StageInstance",)), + Event.stage_instance_delete: EventData(("StageInstance", "StageInstance")), + Event.stage_instance_update: EventData(("StageInstance",)), + Event.application_command: EventData(("ApplicationCommandInteraction",)), + Event.application_command_autocomplete: EventData(("ApplicationCommandInteraction",)), + Event.button_click: EventData(("MessageInteraction",)), + Event.dropdown: EventData(("MessageInteraction",)), + Event.interaction: EventData(("Interaction",)), + Event.message_interaction: EventData(("MessageInteraction",)), + Event.modal_submit: EventData(("ModalInteraction",)), + Event.message: EventData(("Message",)), + Event.message_edit: EventData(("Message", "Message")), + Event.message_delete: EventData(("Message",)), + Event.bulk_message_delete: EventData(("List[Message]",)), + Event.raw_message_edit: EventData(("RawMessageUpdateEvent",)), + Event.raw_message_delete: EventData(("RawMessageDeleteEvent",)), + Event.raw_bulk_message_delete: EventData(("RawBulkMessageDeleteEvent",)), + Event.reaction_add: EventData(("Reaction", "Union[Member, User]")), + Event.reaction_remove: EventData(("Reaction", "Union[Member, User]")), + Event.reaction_clear: EventData(("Message", "List[Reaction]")), + Event.reaction_clear_emoji: EventData(("Reaction",)), + Event.raw_reaction_add: EventData(("RawReactionActionEvent",)), + Event.raw_reaction_remove: EventData(("RawReactionActionEvent",)), + Event.raw_reaction_clear: EventData(("RawReactionClearEvent",)), + Event.raw_reaction_clear_emoji: EventData(("RawReactionClearEmojiEvent",)), + Event.typing: EventData( + ("Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime") + ), + Event.raw_typing: EventData(("RawTypingEvent",)), + Event.command: EventData(("commands.Context",), bot=True), + Event.command_completion: EventData(("commands.Context",), bot=True), + Event.command_error: EventData(("commands.Context", "commands.CommandError"), bot=True), + Event.slash_command: EventData(("ApplicationCommandInteraction",), bot=True), + Event.slash_command_completion: EventData(("ApplicationCommandInteraction",), bot=True), + Event.slash_command_error: EventData( + ("ApplicationCommandInteraction", "commands.CommandError"), bot=True + ), + Event.user_command: EventData(("ApplicationCommandInteraction",), bot=True), + Event.user_command_completion: EventData(("ApplicationCommandInteraction",), bot=True), + Event.user_command_error: EventData( + ("ApplicationCommandInteraction", "commands.CommandError"), bot=True + ), + Event.message_command: EventData(("ApplicationCommandInteraction",), bot=True), + Event.message_command_completion: EventData(("ApplicationCommandInteraction",), bot=True), + Event.message_command_error: EventData( + ("ApplicationCommandInteraction", "commands.CommandError"), bot=True + ), +} From a7da890832c36c32adcdb28ed36e4e7597f58f69 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Wed, 26 Apr 2023 22:02:00 +0200 Subject: [PATCH 02/22] feat: restrict ext.commands-specific overloads to bot classes --- disnake/client.py | 33 ++++++++++++++++++-------------- scripts/codemods/typed_events.py | 7 +++++++ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/disnake/client.py b/disnake/client.py index c84012a031..145a378531 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -118,6 +118,13 @@ from .types.gateway import SessionStartLimit as SessionStartLimitPayload from .voice_client import VoiceProtocol + AnyBot = Union[ + commands.Bot, + commands.AutoShardedBot, + commands.InteractionBot, + commands.AutoShardedInteractionBot, + ] + __all__ = ( "Client", @@ -1544,8 +1551,6 @@ async def wait_until_first_connect(self) -> None: """ await self._first_connect.wait() - # all these overloads are autogenerated, see ./scripts - @overload @_generated def wait_for( @@ -2597,7 +2602,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.command, "command"], *, check: Optional[Callable[[commands.Context], bool]] = None, @@ -2608,7 +2613,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.command_completion, "command_completion"], *, check: Optional[Callable[[commands.Context], bool]] = None, @@ -2619,7 +2624,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.command_error, "command_error"], *, check: Optional[Callable[[commands.Context, commands.CommandError], bool]] = None, @@ -2630,7 +2635,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.slash_command, "slash_command"], *, check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, @@ -2641,7 +2646,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.slash_command_completion, "slash_command_completion"], *, check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, @@ -2652,7 +2657,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.slash_command_error, "slash_command_error"], *, check: Optional[ @@ -2665,7 +2670,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.user_command, "user_command"], *, check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, @@ -2676,7 +2681,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.user_command_completion, "user_command_completion"], *, check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, @@ -2687,7 +2692,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.user_command_error, "user_command_error"], *, check: Optional[ @@ -2700,7 +2705,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.message_command, "message_command"], *, check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, @@ -2711,7 +2716,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.message_command_completion, "message_command_completion"], *, check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, @@ -2722,7 +2727,7 @@ def wait_for( @overload @_generated def wait_for( - self, + self: AnyBot, event: Literal[Event.message_command_error, "message_command_error"], *, check: Optional[ diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 3189342f66..52c46444ca 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -101,6 +101,13 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): ) new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) + # set `self` annotation as a workaround for overloads in subclasses + if event_data.bot: + new_overload = new_overload.with_deep_changes( + get_param(new_overload, "self"), + annotation=cst.Annotation(cst.Name("AnyBot")), + ) + new_overloads.append(new_overload) return cst.FlattenSentinel([*new_overloads, node]) From 2058a7495b7b0718f17ff1b1a91907c4edbfb810 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 27 Apr 2023 12:58:48 +0200 Subject: [PATCH 03/22] feat: add `wait_for` fallback for `event: str` This is not ideal, since it bypasses the overloads if there's a typing issue in your wait_for call when using a str event. When using `Event.*`, the `check` and return types are still enforced as they should be. --- disnake/client.py | 12 ++++++++++++ scripts/codemods/typed_events.py | 3 ++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/disnake/client.py b/disnake/client.py index 145a378531..197b8005a9 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -2737,7 +2737,19 @@ def wait_for( ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: ... + # fallback for custom events + + @overload @_overload_with_events + def wait_for( + self, + event: str, + *, + check: Optional[Callable[..., bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Any]: + ... + def wait_for( self, event: Union[str, Event], diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 52c46444ca..de57083041 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -54,7 +54,7 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): # if we're here, we found a @_overload_with_events decorator new_overloads: List[cst.FunctionDef] = [] - for event in EVENT_DATA.keys(): + for event in Event: event_data = EVENT_DATA[event] if event_data.event_only: continue @@ -66,6 +66,7 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): cst.Decorator(cst.Name("overload")), cst.Decorator(cst.Name("_generated")), ], + leading_lines=(), ) # set `event` annotation From 0b19feb34cdf13785aa0a48a55ea7b9556d9ba29 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 27 Apr 2023 13:13:36 +0200 Subject: [PATCH 04/22] test: add typing tests --- tests/test_events.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index 14e6a649c4..e6ce272c39 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: MIT -from typing import Any +from typing import Any, Coroutine import pytest +from typing_extensions import assert_type import disnake from disnake import Event @@ -47,6 +48,32 @@ def test_wait_for(bot: commands.Bot, event) -> None: coro.close() # close coroutine to avoid warning +def _test_typing_wait_for() -> None: + expected_type = Coroutine[Any, Any, disnake.Guild] + client = disnake.Client() + + # valid enum event + _ = assert_type(client.wait_for(Event.guild_join), expected_type) + _ = assert_type(client.wait_for(Event.guild_join, check=lambda g: True), expected_type) + + # valid str event + _ = assert_type(client.wait_for("guild_join"), expected_type) + _ = assert_type(client.wait_for("guild_join", check=lambda g: True), expected_type) + + # invalid check type + _ = client.wait_for(Event.guild_join, check=lambda: True) # type: ignore + # n.b. this one isn't ideal, but there's no way to prevent type-checkers from using the fallback in this case + _ = assert_type(client.wait_for("guild_join", check=lambda: True), Coroutine[Any, Any, Any]) + + # bot-specific events + bot = commands.Bot(command_prefix=commands.when_mentioned) + _ = client.wait_for(Event.slash_command_error) # type: ignore + _ = assert_type( + bot.wait_for(Event.slash_command), + Coroutine[Any, Any, disnake.ApplicationCommandInteraction], + ) + + # Bot.add_listener / Bot.remove_listener From 2e5d765086d371bcfc5ded52e21d80ac065a0351 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 27 Apr 2023 13:30:58 +0200 Subject: [PATCH 05/22] refactor: move thingy to a separate method --- scripts/codemods/typed_events.py | 103 +++++++++++++++++-------------- 1 file changed, 56 insertions(+), 47 deletions(-) diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index de57083041..8ba0f93165 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: MIT +from __future__ import annotations + import types from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, cast @@ -58,60 +60,67 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): event_data = EVENT_DATA[event] if event_data.event_only: continue - args = event_data.args + new_overloads.append(self.generate_wait_for_overload(node, event, event_data)) - new_overload = node.with_changes( - body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), - decorators=[ - cst.Decorator(cst.Name("overload")), - cst.Decorator(cst.Name("_generated")), - ], - leading_lines=(), - ) + return cst.FlattenSentinel([*new_overloads, node]) - # set `event` annotation - new_annotation = cst.parse_expression( - # the lazy way of doing things - f'Literal[Event.{event.name}, "{event.value}"]', - config=self.module.config_for_parsing, - ) - new_overload = new_overload.with_deep_changes( - get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) - ) + def generate_wait_for_overload( + self, func: cst.FunctionDef, event: Event, event_data: EventData + ) -> cst.FunctionDef: + args = event_data.args - # set `check` annotation - callable_annotation = m.findall( - get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) - )[0] - callable_params = m.findall(callable_annotation, m.Ellipsis())[0] - new_annotation = cst.parse_expression( - f'[{",".join(args)}]', - config=self.module.config_for_parsing, - ) - new_overload = new_overload.deep_replace(callable_params, new_annotation) + new_overload = func.with_changes( + body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), + decorators=[ + cst.Decorator(cst.Name("overload")), + cst.Decorator(cst.Name("_generated")), + ], + leading_lines=(), + ) - # set return annotation - if len(args) == 0: - new_annotation_str = "None" - elif len(args) == 1: - new_annotation_str = args[0] - else: - new_annotation_str = f'Tuple[{",".join(args)}]' - new_annotation = cst.parse_expression( - f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing - ) - new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) + # set `event` annotation + new_annotation = cst.parse_expression( + # the lazy way of doing things + f'Literal[Event.{event.name}, "{event.value}"]', + config=self.module.config_for_parsing, + ) + new_overload = new_overload.with_deep_changes( + get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) + ) - # set `self` annotation as a workaround for overloads in subclasses - if event_data.bot: - new_overload = new_overload.with_deep_changes( - get_param(new_overload, "self"), - annotation=cst.Annotation(cst.Name("AnyBot")), - ) + # set `check` annotation + callable_annotation = m.findall( + get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) + )[0] + callable_params = m.findall(callable_annotation, m.Ellipsis())[0] + new_annotation = cst.parse_expression( + f'[{",".join(args)}]', + config=self.module.config_for_parsing, + ) + new_overload = cast( + cst.FunctionDef, new_overload.deep_replace(callable_params, new_annotation) + ) - new_overloads.append(new_overload) + # set return annotation + if len(args) == 0: + new_annotation_str = "None" + elif len(args) == 1: + new_annotation_str = args[0] + else: + new_annotation_str = f'Tuple[{",".join(args)}]' + new_annotation = cst.parse_expression( + f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing + ) + new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) - return cst.FlattenSentinel([*new_overloads, node]) + # set `self` annotation as a workaround for overloads in subclasses + if event_data.bot: + new_overload = new_overload.with_deep_changes( + get_param(new_overload, "self"), + annotation=cst.Annotation(cst.Name("AnyBot")), + ) + + return new_overload @dataclass From 8c29ed55bb304b96aa04065c579bd3c57661d015 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 28 Apr 2023 14:51:30 +0200 Subject: [PATCH 06/22] refactor: make codemod more extensible --- scripts/codemods/typed_events.py | 49 +++++++++++++++++++------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 8ba0f93165..846675bc01 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -49,7 +49,9 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): # ignore return node - if node.name.value != "wait_for": + if node.name.value == "wait_for": + generator = self.generate_wait_for_overload + else: raise RuntimeError( f"unknown method '{node.name.value}' with @_overload_with_events decorator" ) @@ -60,16 +62,12 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): event_data = EVENT_DATA[event] if event_data.event_only: continue - new_overloads.append(self.generate_wait_for_overload(node, event, event_data)) + new_overloads.append(generator(node, event, event_data)) return cst.FlattenSentinel([*new_overloads, node]) - def generate_wait_for_overload( - self, func: cst.FunctionDef, event: Event, event_data: EventData - ) -> cst.FunctionDef: - args = event_data.args - - new_overload = func.with_changes( + def create_empty_overload(self, func: cst.FunctionDef) -> cst.FunctionDef: + return func.with_changes( body=cst.IndentedBlock([cst.SimpleStatementLine([cst.Expr(cst.Ellipsis())])]), decorators=[ cst.Decorator(cst.Name("overload")), @@ -78,14 +76,29 @@ def generate_wait_for_overload( leading_lines=(), ) - # set `event` annotation - new_annotation = cst.parse_expression( - # the lazy way of doing things + def create_literal(self, event: Event) -> cst.BaseExpression: + return cst.parse_expression( f'Literal[Event.{event.name}, "{event.value}"]', config=self.module.config_for_parsing, ) + + def create_args_list(self, event_data: EventData) -> cst.BaseExpression: + return cst.parse_expression( + f'[{",".join(event_data.args)}]', + config=self.module.config_for_parsing, + ) + + def generate_wait_for_overload( + self, func: cst.FunctionDef, event: Event, event_data: EventData + ) -> cst.FunctionDef: + args = event_data.args + + new_overload = self.create_empty_overload(func) + + # set `event` annotation new_overload = new_overload.with_deep_changes( - get_param(new_overload, "event"), annotation=cst.Annotation(new_annotation) + get_param(new_overload, "event"), + annotation=cst.Annotation(self.create_literal(event)), ) # set `check` annotation @@ -93,12 +106,9 @@ def generate_wait_for_overload( get_param(new_overload, "check"), m.Subscript(m.Name("Callable")) )[0] callable_params = m.findall(callable_annotation, m.Ellipsis())[0] - new_annotation = cst.parse_expression( - f'[{",".join(args)}]', - config=self.module.config_for_parsing, - ) new_overload = cast( - cst.FunctionDef, new_overload.deep_replace(callable_params, new_annotation) + cst.FunctionDef, + new_overload.deep_replace(callable_params, self.create_args_list(event_data)), ) # set return annotation @@ -109,7 +119,8 @@ def generate_wait_for_overload( else: new_annotation_str = f'Tuple[{",".join(args)}]' new_annotation = cst.parse_expression( - f"Coroutine[Any, Any, {new_annotation_str}]", config=self.module.config_for_parsing + f"Coroutine[Any, Any, {new_annotation_str}]", + config=self.module.config_for_parsing, ) new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) @@ -136,7 +147,7 @@ class EventData: EVENT_DATA: Dict[Event, EventData] = { Event.connect: EventData(()), Event.disconnect: EventData(()), - # TODO: figure out how to specify varargs for these two + # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event Event.error: EventData((), event_only=True), Event.gateway_error: EventData((), event_only=True), Event.ready: EventData(()), From d8961933706a0a703b0630b0f1f33f0691d90d2f Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 28 Apr 2023 15:55:47 +0200 Subject: [PATCH 07/22] feat: add better error message for missing events --- scripts/codemods/typed_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 846675bc01..148748d9bd 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -59,7 +59,8 @@ def leave_FunctionDef(self, _: cst.FunctionDef, node: cst.FunctionDef): # if we're here, we found a @_overload_with_events decorator new_overloads: List[cst.FunctionDef] = [] for event in Event: - event_data = EVENT_DATA[event] + if not (event_data := EVENT_DATA.get(event)): + raise RuntimeError(f"{event} is missing an EVENT_DATA definition") if event_data.event_only: continue new_overloads.append(generator(node, event, event_data)) From cf276c745e4cc9e9776837cbc80336a47a5d2723 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 28 Apr 2023 17:49:47 +0200 Subject: [PATCH 08/22] chore: stuff --- tests/test_events.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index e6ce272c39..bba4682f42 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -48,9 +48,8 @@ def test_wait_for(bot: commands.Bot, event) -> None: coro.close() # close coroutine to avoid warning -def _test_typing_wait_for() -> None: +def _test_typing_wait_for(client: disnake.Client, bot: commands.Bot) -> None: expected_type = Coroutine[Any, Any, disnake.Guild] - client = disnake.Client() # valid enum event _ = assert_type(client.wait_for(Event.guild_join), expected_type) @@ -66,8 +65,7 @@ def _test_typing_wait_for() -> None: _ = assert_type(client.wait_for("guild_join", check=lambda: True), Coroutine[Any, Any, Any]) # bot-specific events - bot = commands.Bot(command_prefix=commands.when_mentioned) - _ = client.wait_for(Event.slash_command_error) # type: ignore + _ = client.wait_for(Event.slash_command_error) # type: ignore # this should error _ = assert_type( bot.wait_for(Event.slash_command), Coroutine[Any, Any, disnake.ApplicationCommandInteraction], From f22006ceedcb85ee1ff6b0eee53e5d762c2cb5af Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 30 Apr 2023 16:24:23 +0200 Subject: [PATCH 09/22] refactor: move event data to separate module --- disnake/_event_data.py | 154 +++++++++++++++++++++++++++++++ scripts/codemods/typed_events.py | 153 +----------------------------- 2 files changed, 158 insertions(+), 149 deletions(-) create mode 100644 disnake/_event_data.py diff --git a/disnake/_event_data.py b/disnake/_event_data.py new file mode 100644 index 0000000000..e364e96d71 --- /dev/null +++ b/disnake/_event_data.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from typing import Dict, List, Tuple + +from .enums import Event + + +class EventData: + type_args: Tuple[str, ...] + """Type names of event arguments, e.g. `("Guild", "User")`""" + + bot: bool + """Whether the event is specific to ext.commands""" + + event_only: bool + """Whether the event can only be used through `@event` and not other listeners""" + + def __init__(self, type_args: List[str], bot: bool = False, event_only: bool = False) -> None: + self.type_args = tuple(type_args) + self.bot = bot + self.event_only = event_only + + +EVENT_DATA: Dict[Event, EventData] = { + Event.connect: EventData([]), + Event.disconnect: EventData([]), + # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event + Event.error: EventData([], event_only=True), + Event.gateway_error: EventData([], event_only=True), + Event.ready: EventData([]), + Event.resumed: EventData([]), + Event.shard_connect: EventData(["int"]), + Event.shard_disconnect: EventData(["int"]), + Event.shard_ready: EventData(["int"]), + Event.shard_resumed: EventData(["int"]), + Event.socket_event_type: EventData(["str"]), + Event.socket_raw_receive: EventData(["str"]), + Event.socket_raw_send: EventData(["Union[str, bytes]"]), + Event.guild_channel_create: EventData(["GuildChannel"]), + Event.guild_channel_update: EventData(["GuildChannel", "GuildChannel"]), + Event.guild_channel_delete: EventData(["GuildChannel"]), + Event.guild_channel_pins_update: EventData( + ["Union[GuildChannel, Thread]", "Optional[datetime]"] + ), + Event.invite_create: EventData(["Invite"]), + Event.invite_delete: EventData(["Invite"]), + Event.private_channel_update: EventData(["GroupChannel", "GroupChannel"]), + Event.private_channel_pins_update: EventData(["PrivateChannel", "Optional[datetime]"]), + Event.webhooks_update: EventData(["GuildChannel"]), + Event.thread_create: EventData(["Thread"]), + Event.thread_update: EventData(["Thread", "Thread"]), + Event.thread_delete: EventData(["Thread"]), + Event.thread_join: EventData(["Thread"]), + Event.thread_remove: EventData(["Thread"]), + Event.thread_member_join: EventData(["ThreadMember"]), + Event.thread_member_remove: EventData(["ThreadMember"]), + Event.raw_thread_member_remove: EventData(["RawThreadMemberRemoveEvent"]), + Event.raw_thread_update: EventData(["Thread"]), + Event.raw_thread_delete: EventData(["RawThreadDeleteEvent"]), + Event.guild_join: EventData(["Guild"]), + Event.guild_remove: EventData(["Guild"]), + Event.guild_update: EventData(["Guild", "Guild"]), + Event.guild_available: EventData(["Guild"]), + Event.guild_unavailable: EventData(["Guild"]), + Event.guild_role_create: EventData(["Role"]), + Event.guild_role_delete: EventData(["Role"]), + Event.guild_role_update: EventData(["Role", "Role"]), + Event.guild_emojis_update: EventData(["Guild", "Sequence[Emoji]", "Sequence[Emoji]"]), + Event.guild_stickers_update: EventData( + ["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"] + ), + Event.guild_integrations_update: EventData(["Guild"]), + Event.guild_scheduled_event_create: EventData(["GuildScheduledEvent"]), + Event.guild_scheduled_event_update: EventData(["GuildScheduledEvent", "GuildScheduledEvent"]), + Event.guild_scheduled_event_delete: EventData(["GuildScheduledEvent"]), + Event.guild_scheduled_event_subscribe: EventData( + ["GuildScheduledEvent", "Union[Member, User]"] + ), + Event.guild_scheduled_event_unsubscribe: EventData( + ["GuildScheduledEvent", "Union[Member, User]"] + ), + Event.raw_guild_scheduled_event_subscribe: EventData(["RawGuildScheduledEventUserActionEvent"]), + Event.raw_guild_scheduled_event_unsubscribe: EventData( + ["RawGuildScheduledEventUserActionEvent"] + ), + Event.application_command_permissions_update: EventData(["GuildApplicationCommandPermissions"]), + Event.automod_action_execution: EventData(["AutoModActionExecution"]), + Event.automod_rule_create: EventData(["AutoModRule"]), + Event.automod_rule_update: EventData(["AutoModRule"]), + Event.automod_rule_delete: EventData(["AutoModRule"]), + Event.audit_log_entry_create: EventData(["AuditLogEntry"]), + Event.integration_create: EventData(["Integration"]), + Event.integration_update: EventData(["Integration"]), + Event.raw_integration_delete: EventData(["RawIntegrationDeleteEvent"]), + Event.member_join: EventData(["Member"]), + Event.member_remove: EventData(["Member"]), + Event.member_update: EventData(["Member", "Member"]), + Event.raw_member_remove: EventData(["RawGuildMemberRemoveEvent"]), + Event.raw_member_update: EventData(["Member"]), + Event.member_ban: EventData(["Guild", "Union[User, Member]"]), + Event.member_unban: EventData(["Guild", "User"]), + Event.presence_update: EventData(["Member", "Member"]), + Event.user_update: EventData(["User", "User"]), + Event.voice_state_update: EventData(["Member", "VoiceState", "VoiceState"]), + Event.stage_instance_create: EventData(["StageInstance"]), + Event.stage_instance_delete: EventData(["StageInstance", "StageInstance"]), + Event.stage_instance_update: EventData(["StageInstance"]), + Event.application_command: EventData(["ApplicationCommandInteraction"]), + Event.application_command_autocomplete: EventData(["ApplicationCommandInteraction"]), + Event.button_click: EventData(["MessageInteraction"]), + Event.dropdown: EventData(["MessageInteraction"]), + Event.interaction: EventData(["Interaction"]), + Event.message_interaction: EventData(["MessageInteraction"]), + Event.modal_submit: EventData(["ModalInteraction"]), + Event.message: EventData(["Message"]), + Event.message_edit: EventData(["Message", "Message"]), + Event.message_delete: EventData(["Message"]), + Event.bulk_message_delete: EventData(["List[Message]"]), + Event.raw_message_edit: EventData(["RawMessageUpdateEvent"]), + Event.raw_message_delete: EventData(["RawMessageDeleteEvent"]), + Event.raw_bulk_message_delete: EventData(["RawBulkMessageDeleteEvent"]), + Event.reaction_add: EventData(["Reaction", "Union[Member, User]"]), + Event.reaction_remove: EventData(["Reaction", "Union[Member, User]"]), + Event.reaction_clear: EventData(["Message", "List[Reaction]"]), + Event.reaction_clear_emoji: EventData(["Reaction"]), + Event.raw_reaction_add: EventData(["RawReactionActionEvent"]), + Event.raw_reaction_remove: EventData(["RawReactionActionEvent"]), + Event.raw_reaction_clear: EventData(["RawReactionClearEvent"]), + Event.raw_reaction_clear_emoji: EventData(["RawReactionClearEmojiEvent"]), + Event.typing: EventData( + ["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"] + ), + Event.raw_typing: EventData(["RawTypingEvent"]), + Event.command: EventData(["commands.Context"], bot=True), + Event.command_completion: EventData(["commands.Context"], bot=True), + Event.command_error: EventData(["commands.Context", "commands.CommandError"], bot=True), + Event.slash_command: EventData(["ApplicationCommandInteraction"], bot=True), + Event.slash_command_completion: EventData(["ApplicationCommandInteraction"], bot=True), + Event.slash_command_error: EventData( + ["ApplicationCommandInteraction", "commands.CommandError"], bot=True + ), + Event.user_command: EventData(["ApplicationCommandInteraction"], bot=True), + Event.user_command_completion: EventData(["ApplicationCommandInteraction"], bot=True), + Event.user_command_error: EventData( + ["ApplicationCommandInteraction", "commands.CommandError"], bot=True + ), + Event.message_command: EventData(["ApplicationCommandInteraction"], bot=True), + Event.message_command_completion: EventData(["ApplicationCommandInteraction"], bot=True), + Event.message_command_error: EventData( + ["ApplicationCommandInteraction", "commands.CommandError"], bot=True + ), +} diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 148748d9bd..369ef83b45 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -3,14 +3,14 @@ from __future__ import annotations import types -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, cast +from typing import List, Optional, cast import libcst as cst import libcst.matchers as m from libcst import codemod from disnake import Event +from disnake._event_data import EVENT_DATA, EventData def get_param(func: cst.FunctionDef, name: str) -> cst.Param: @@ -85,14 +85,14 @@ def create_literal(self, event: Event) -> cst.BaseExpression: def create_args_list(self, event_data: EventData) -> cst.BaseExpression: return cst.parse_expression( - f'[{",".join(event_data.args)}]', + f'[{",".join(event_data.type_args)}]', config=self.module.config_for_parsing, ) def generate_wait_for_overload( self, func: cst.FunctionDef, event: Event, event_data: EventData ) -> cst.FunctionDef: - args = event_data.args + args = event_data.type_args new_overload = self.create_empty_overload(func) @@ -133,148 +133,3 @@ def generate_wait_for_overload( ) return new_overload - - -@dataclass -class EventData: - # type names of event arguments, e.g. `("Guild", "User")` - args: Tuple[str, ...] - # whether the event is specific to ext.commands - bot: bool = False - # whether the event can only be used through `@event` and not other listeners - event_only: bool = False - - -EVENT_DATA: Dict[Event, EventData] = { - Event.connect: EventData(()), - Event.disconnect: EventData(()), - # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event - Event.error: EventData((), event_only=True), - Event.gateway_error: EventData((), event_only=True), - Event.ready: EventData(()), - Event.resumed: EventData(()), - Event.shard_connect: EventData(("int",)), - Event.shard_disconnect: EventData(("int",)), - Event.shard_ready: EventData(("int",)), - Event.shard_resumed: EventData(("int",)), - Event.socket_event_type: EventData(("str",)), - Event.socket_raw_receive: EventData(("str",)), - Event.socket_raw_send: EventData(("Union[str, bytes]",)), - Event.guild_channel_create: EventData(("GuildChannel",)), - Event.guild_channel_update: EventData(("GuildChannel", "GuildChannel")), - Event.guild_channel_delete: EventData(("GuildChannel",)), - Event.guild_channel_pins_update: EventData( - ("Union[GuildChannel, Thread]", "Optional[datetime]") - ), - Event.invite_create: EventData(("Invite",)), - Event.invite_delete: EventData(("Invite",)), - Event.private_channel_update: EventData(("GroupChannel", "GroupChannel")), - Event.private_channel_pins_update: EventData(("PrivateChannel", "Optional[datetime]")), - Event.webhooks_update: EventData(("GuildChannel",)), - Event.thread_create: EventData(("Thread",)), - Event.thread_update: EventData(("Thread", "Thread")), - Event.thread_delete: EventData(("Thread",)), - Event.thread_join: EventData(("Thread",)), - Event.thread_remove: EventData(("Thread",)), - Event.thread_member_join: EventData(("ThreadMember",)), - Event.thread_member_remove: EventData(("ThreadMember",)), - Event.raw_thread_member_remove: EventData(("RawThreadMemberRemoveEvent",)), - Event.raw_thread_update: EventData(("Thread",)), - Event.raw_thread_delete: EventData(("RawThreadDeleteEvent",)), - Event.guild_join: EventData(("Guild",)), - Event.guild_remove: EventData(("Guild",)), - Event.guild_update: EventData(("Guild", "Guild")), - Event.guild_available: EventData(("Guild",)), - Event.guild_unavailable: EventData(("Guild",)), - Event.guild_role_create: EventData(("Role",)), - Event.guild_role_delete: EventData(("Role",)), - Event.guild_role_update: EventData(("Role", "Role")), - Event.guild_emojis_update: EventData(("Guild", "Sequence[Emoji]", "Sequence[Emoji]")), - Event.guild_stickers_update: EventData( - ("Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]") - ), - Event.guild_integrations_update: EventData(("Guild",)), - Event.guild_scheduled_event_create: EventData(("GuildScheduledEvent",)), - Event.guild_scheduled_event_update: EventData(("GuildScheduledEvent", "GuildScheduledEvent")), - Event.guild_scheduled_event_delete: EventData(("GuildScheduledEvent",)), - Event.guild_scheduled_event_subscribe: EventData( - ("GuildScheduledEvent", "Union[Member, User]") - ), - Event.guild_scheduled_event_unsubscribe: EventData( - ("GuildScheduledEvent", "Union[Member, User]") - ), - Event.raw_guild_scheduled_event_subscribe: EventData( - ("RawGuildScheduledEventUserActionEvent",) - ), - Event.raw_guild_scheduled_event_unsubscribe: EventData( - ("RawGuildScheduledEventUserActionEvent",) - ), - Event.application_command_permissions_update: EventData( - ("GuildApplicationCommandPermissions",) - ), - Event.automod_action_execution: EventData(("AutoModActionExecution",)), - Event.automod_rule_create: EventData(("AutoModRule",)), - Event.automod_rule_update: EventData(("AutoModRule",)), - Event.automod_rule_delete: EventData(("AutoModRule",)), - Event.audit_log_entry_create: EventData(("AuditLogEntry",)), - Event.integration_create: EventData(("Integration",)), - Event.integration_update: EventData(("Integration",)), - Event.raw_integration_delete: EventData(("RawIntegrationDeleteEvent",)), - Event.member_join: EventData(("Member",)), - Event.member_remove: EventData(("Member",)), - Event.member_update: EventData(("Member", "Member")), - Event.raw_member_remove: EventData(("RawGuildMemberRemoveEvent",)), - Event.raw_member_update: EventData(("Member",)), - Event.member_ban: EventData(("Guild", "Union[User, Member]")), - Event.member_unban: EventData(("Guild", "User")), - Event.presence_update: EventData(("Member", "Member")), - Event.user_update: EventData(("User", "User")), - Event.voice_state_update: EventData(("Member", "VoiceState", "VoiceState")), - Event.stage_instance_create: EventData(("StageInstance",)), - Event.stage_instance_delete: EventData(("StageInstance", "StageInstance")), - Event.stage_instance_update: EventData(("StageInstance",)), - Event.application_command: EventData(("ApplicationCommandInteraction",)), - Event.application_command_autocomplete: EventData(("ApplicationCommandInteraction",)), - Event.button_click: EventData(("MessageInteraction",)), - Event.dropdown: EventData(("MessageInteraction",)), - Event.interaction: EventData(("Interaction",)), - Event.message_interaction: EventData(("MessageInteraction",)), - Event.modal_submit: EventData(("ModalInteraction",)), - Event.message: EventData(("Message",)), - Event.message_edit: EventData(("Message", "Message")), - Event.message_delete: EventData(("Message",)), - Event.bulk_message_delete: EventData(("List[Message]",)), - Event.raw_message_edit: EventData(("RawMessageUpdateEvent",)), - Event.raw_message_delete: EventData(("RawMessageDeleteEvent",)), - Event.raw_bulk_message_delete: EventData(("RawBulkMessageDeleteEvent",)), - Event.reaction_add: EventData(("Reaction", "Union[Member, User]")), - Event.reaction_remove: EventData(("Reaction", "Union[Member, User]")), - Event.reaction_clear: EventData(("Message", "List[Reaction]")), - Event.reaction_clear_emoji: EventData(("Reaction",)), - Event.raw_reaction_add: EventData(("RawReactionActionEvent",)), - Event.raw_reaction_remove: EventData(("RawReactionActionEvent",)), - Event.raw_reaction_clear: EventData(("RawReactionClearEvent",)), - Event.raw_reaction_clear_emoji: EventData(("RawReactionClearEmojiEvent",)), - Event.typing: EventData( - ("Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime") - ), - Event.raw_typing: EventData(("RawTypingEvent",)), - Event.command: EventData(("commands.Context",), bot=True), - Event.command_completion: EventData(("commands.Context",), bot=True), - Event.command_error: EventData(("commands.Context", "commands.CommandError"), bot=True), - Event.slash_command: EventData(("ApplicationCommandInteraction",), bot=True), - Event.slash_command_completion: EventData(("ApplicationCommandInteraction",), bot=True), - Event.slash_command_error: EventData( - ("ApplicationCommandInteraction", "commands.CommandError"), bot=True - ), - Event.user_command: EventData(("ApplicationCommandInteraction",), bot=True), - Event.user_command_completion: EventData(("ApplicationCommandInteraction",), bot=True), - Event.user_command_error: EventData( - ("ApplicationCommandInteraction", "commands.CommandError"), bot=True - ), - Event.message_command: EventData(("ApplicationCommandInteraction",), bot=True), - Event.message_command_completion: EventData(("ApplicationCommandInteraction",), bot=True), - Event.message_command_error: EventData( - ("ApplicationCommandInteraction", "commands.CommandError"), bot=True - ), -} From 624fc369e2daa59747cca1e10dfdad004829c2fc Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 30 Apr 2023 16:27:41 +0200 Subject: [PATCH 10/22] chore: make `EventData` kwarg-only --- disnake/_event_data.py | 244 ++++++++++++++++++++++------------------- 1 file changed, 132 insertions(+), 112 deletions(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index e364e96d71..04fad2e35c 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -17,138 +17,158 @@ class EventData: event_only: bool """Whether the event can only be used through `@event` and not other listeners""" - def __init__(self, type_args: List[str], bot: bool = False, event_only: bool = False) -> None: + def __init__( + self, + *, + type_args: List[str], + bot: bool = False, + event_only: bool = False, + ) -> None: self.type_args = tuple(type_args) self.bot = bot self.event_only = event_only EVENT_DATA: Dict[Event, EventData] = { - Event.connect: EventData([]), - Event.disconnect: EventData([]), + Event.connect: EventData(type_args=[]), + Event.disconnect: EventData(type_args=[]), # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event - Event.error: EventData([], event_only=True), - Event.gateway_error: EventData([], event_only=True), - Event.ready: EventData([]), - Event.resumed: EventData([]), - Event.shard_connect: EventData(["int"]), - Event.shard_disconnect: EventData(["int"]), - Event.shard_ready: EventData(["int"]), - Event.shard_resumed: EventData(["int"]), - Event.socket_event_type: EventData(["str"]), - Event.socket_raw_receive: EventData(["str"]), - Event.socket_raw_send: EventData(["Union[str, bytes]"]), - Event.guild_channel_create: EventData(["GuildChannel"]), - Event.guild_channel_update: EventData(["GuildChannel", "GuildChannel"]), - Event.guild_channel_delete: EventData(["GuildChannel"]), + Event.error: EventData(type_args=[], event_only=True), + Event.gateway_error: EventData(type_args=[], event_only=True), + Event.ready: EventData(type_args=[]), + Event.resumed: EventData(type_args=[]), + Event.shard_connect: EventData(type_args=["int"]), + Event.shard_disconnect: EventData(type_args=["int"]), + Event.shard_ready: EventData(type_args=["int"]), + Event.shard_resumed: EventData(type_args=["int"]), + Event.socket_event_type: EventData(type_args=["str"]), + Event.socket_raw_receive: EventData(type_args=["str"]), + Event.socket_raw_send: EventData(type_args=["Union[str, bytes]"]), + Event.guild_channel_create: EventData(type_args=["GuildChannel"]), + Event.guild_channel_update: EventData(type_args=["GuildChannel", "GuildChannel"]), + Event.guild_channel_delete: EventData(type_args=["GuildChannel"]), Event.guild_channel_pins_update: EventData( - ["Union[GuildChannel, Thread]", "Optional[datetime]"] - ), - Event.invite_create: EventData(["Invite"]), - Event.invite_delete: EventData(["Invite"]), - Event.private_channel_update: EventData(["GroupChannel", "GroupChannel"]), - Event.private_channel_pins_update: EventData(["PrivateChannel", "Optional[datetime]"]), - Event.webhooks_update: EventData(["GuildChannel"]), - Event.thread_create: EventData(["Thread"]), - Event.thread_update: EventData(["Thread", "Thread"]), - Event.thread_delete: EventData(["Thread"]), - Event.thread_join: EventData(["Thread"]), - Event.thread_remove: EventData(["Thread"]), - Event.thread_member_join: EventData(["ThreadMember"]), - Event.thread_member_remove: EventData(["ThreadMember"]), - Event.raw_thread_member_remove: EventData(["RawThreadMemberRemoveEvent"]), - Event.raw_thread_update: EventData(["Thread"]), - Event.raw_thread_delete: EventData(["RawThreadDeleteEvent"]), - Event.guild_join: EventData(["Guild"]), - Event.guild_remove: EventData(["Guild"]), - Event.guild_update: EventData(["Guild", "Guild"]), - Event.guild_available: EventData(["Guild"]), - Event.guild_unavailable: EventData(["Guild"]), - Event.guild_role_create: EventData(["Role"]), - Event.guild_role_delete: EventData(["Role"]), - Event.guild_role_update: EventData(["Role", "Role"]), - Event.guild_emojis_update: EventData(["Guild", "Sequence[Emoji]", "Sequence[Emoji]"]), + type_args=["Union[GuildChannel, Thread]", "Optional[datetime]"] + ), + Event.invite_create: EventData(type_args=["Invite"]), + Event.invite_delete: EventData(type_args=["Invite"]), + Event.private_channel_update: EventData(type_args=["GroupChannel", "GroupChannel"]), + Event.private_channel_pins_update: EventData( + type_args=["PrivateChannel", "Optional[datetime]"] + ), + Event.webhooks_update: EventData(type_args=["GuildChannel"]), + Event.thread_create: EventData(type_args=["Thread"]), + Event.thread_update: EventData(type_args=["Thread", "Thread"]), + Event.thread_delete: EventData(type_args=["Thread"]), + Event.thread_join: EventData(type_args=["Thread"]), + Event.thread_remove: EventData(type_args=["Thread"]), + Event.thread_member_join: EventData(type_args=["ThreadMember"]), + Event.thread_member_remove: EventData(type_args=["ThreadMember"]), + Event.raw_thread_member_remove: EventData(type_args=["RawThreadMemberRemoveEvent"]), + Event.raw_thread_update: EventData(type_args=["Thread"]), + Event.raw_thread_delete: EventData(type_args=["RawThreadDeleteEvent"]), + Event.guild_join: EventData(type_args=["Guild"]), + Event.guild_remove: EventData(type_args=["Guild"]), + Event.guild_update: EventData(type_args=["Guild", "Guild"]), + Event.guild_available: EventData(type_args=["Guild"]), + Event.guild_unavailable: EventData(type_args=["Guild"]), + Event.guild_role_create: EventData(type_args=["Role"]), + Event.guild_role_delete: EventData(type_args=["Role"]), + Event.guild_role_update: EventData(type_args=["Role", "Role"]), + Event.guild_emojis_update: EventData(type_args=["Guild", "Sequence[Emoji]", "Sequence[Emoji]"]), Event.guild_stickers_update: EventData( - ["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"] + type_args=["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"] + ), + Event.guild_integrations_update: EventData(type_args=["Guild"]), + Event.guild_scheduled_event_create: EventData(type_args=["GuildScheduledEvent"]), + Event.guild_scheduled_event_update: EventData( + type_args=["GuildScheduledEvent", "GuildScheduledEvent"] ), - Event.guild_integrations_update: EventData(["Guild"]), - Event.guild_scheduled_event_create: EventData(["GuildScheduledEvent"]), - Event.guild_scheduled_event_update: EventData(["GuildScheduledEvent", "GuildScheduledEvent"]), - Event.guild_scheduled_event_delete: EventData(["GuildScheduledEvent"]), + Event.guild_scheduled_event_delete: EventData(type_args=["GuildScheduledEvent"]), Event.guild_scheduled_event_subscribe: EventData( - ["GuildScheduledEvent", "Union[Member, User]"] + type_args=["GuildScheduledEvent", "Union[Member, User]"] ), Event.guild_scheduled_event_unsubscribe: EventData( - ["GuildScheduledEvent", "Union[Member, User]"] + type_args=["GuildScheduledEvent", "Union[Member, User]"] + ), + Event.raw_guild_scheduled_event_subscribe: EventData( + type_args=["RawGuildScheduledEventUserActionEvent"] ), - Event.raw_guild_scheduled_event_subscribe: EventData(["RawGuildScheduledEventUserActionEvent"]), Event.raw_guild_scheduled_event_unsubscribe: EventData( - ["RawGuildScheduledEventUserActionEvent"] - ), - Event.application_command_permissions_update: EventData(["GuildApplicationCommandPermissions"]), - Event.automod_action_execution: EventData(["AutoModActionExecution"]), - Event.automod_rule_create: EventData(["AutoModRule"]), - Event.automod_rule_update: EventData(["AutoModRule"]), - Event.automod_rule_delete: EventData(["AutoModRule"]), - Event.audit_log_entry_create: EventData(["AuditLogEntry"]), - Event.integration_create: EventData(["Integration"]), - Event.integration_update: EventData(["Integration"]), - Event.raw_integration_delete: EventData(["RawIntegrationDeleteEvent"]), - Event.member_join: EventData(["Member"]), - Event.member_remove: EventData(["Member"]), - Event.member_update: EventData(["Member", "Member"]), - Event.raw_member_remove: EventData(["RawGuildMemberRemoveEvent"]), - Event.raw_member_update: EventData(["Member"]), - Event.member_ban: EventData(["Guild", "Union[User, Member]"]), - Event.member_unban: EventData(["Guild", "User"]), - Event.presence_update: EventData(["Member", "Member"]), - Event.user_update: EventData(["User", "User"]), - Event.voice_state_update: EventData(["Member", "VoiceState", "VoiceState"]), - Event.stage_instance_create: EventData(["StageInstance"]), - Event.stage_instance_delete: EventData(["StageInstance", "StageInstance"]), - Event.stage_instance_update: EventData(["StageInstance"]), - Event.application_command: EventData(["ApplicationCommandInteraction"]), - Event.application_command_autocomplete: EventData(["ApplicationCommandInteraction"]), - Event.button_click: EventData(["MessageInteraction"]), - Event.dropdown: EventData(["MessageInteraction"]), - Event.interaction: EventData(["Interaction"]), - Event.message_interaction: EventData(["MessageInteraction"]), - Event.modal_submit: EventData(["ModalInteraction"]), - Event.message: EventData(["Message"]), - Event.message_edit: EventData(["Message", "Message"]), - Event.message_delete: EventData(["Message"]), - Event.bulk_message_delete: EventData(["List[Message]"]), - Event.raw_message_edit: EventData(["RawMessageUpdateEvent"]), - Event.raw_message_delete: EventData(["RawMessageDeleteEvent"]), - Event.raw_bulk_message_delete: EventData(["RawBulkMessageDeleteEvent"]), - Event.reaction_add: EventData(["Reaction", "Union[Member, User]"]), - Event.reaction_remove: EventData(["Reaction", "Union[Member, User]"]), - Event.reaction_clear: EventData(["Message", "List[Reaction]"]), - Event.reaction_clear_emoji: EventData(["Reaction"]), - Event.raw_reaction_add: EventData(["RawReactionActionEvent"]), - Event.raw_reaction_remove: EventData(["RawReactionActionEvent"]), - Event.raw_reaction_clear: EventData(["RawReactionClearEvent"]), - Event.raw_reaction_clear_emoji: EventData(["RawReactionClearEmojiEvent"]), + type_args=["RawGuildScheduledEventUserActionEvent"] + ), + Event.application_command_permissions_update: EventData( + type_args=["GuildApplicationCommandPermissions"] + ), + Event.automod_action_execution: EventData(type_args=["AutoModActionExecution"]), + Event.automod_rule_create: EventData(type_args=["AutoModRule"]), + Event.automod_rule_update: EventData(type_args=["AutoModRule"]), + Event.automod_rule_delete: EventData(type_args=["AutoModRule"]), + Event.audit_log_entry_create: EventData(type_args=["AuditLogEntry"]), + Event.integration_create: EventData(type_args=["Integration"]), + Event.integration_update: EventData(type_args=["Integration"]), + Event.raw_integration_delete: EventData(type_args=["RawIntegrationDeleteEvent"]), + Event.member_join: EventData(type_args=["Member"]), + Event.member_remove: EventData(type_args=["Member"]), + Event.member_update: EventData(type_args=["Member", "Member"]), + Event.raw_member_remove: EventData(type_args=["RawGuildMemberRemoveEvent"]), + Event.raw_member_update: EventData(type_args=["Member"]), + Event.member_ban: EventData(type_args=["Guild", "Union[User, Member]"]), + Event.member_unban: EventData(type_args=["Guild", "User"]), + Event.presence_update: EventData(type_args=["Member", "Member"]), + Event.user_update: EventData(type_args=["User", "User"]), + Event.voice_state_update: EventData(type_args=["Member", "VoiceState", "VoiceState"]), + Event.stage_instance_create: EventData(type_args=["StageInstance"]), + Event.stage_instance_delete: EventData(type_args=["StageInstance", "StageInstance"]), + Event.stage_instance_update: EventData(type_args=["StageInstance"]), + Event.application_command: EventData(type_args=["ApplicationCommandInteraction"]), + Event.application_command_autocomplete: EventData(type_args=["ApplicationCommandInteraction"]), + Event.button_click: EventData(type_args=["MessageInteraction"]), + Event.dropdown: EventData(type_args=["MessageInteraction"]), + Event.interaction: EventData(type_args=["Interaction"]), + Event.message_interaction: EventData(type_args=["MessageInteraction"]), + Event.modal_submit: EventData(type_args=["ModalInteraction"]), + Event.message: EventData(type_args=["Message"]), + Event.message_edit: EventData(type_args=["Message", "Message"]), + Event.message_delete: EventData(type_args=["Message"]), + Event.bulk_message_delete: EventData(type_args=["List[Message]"]), + Event.raw_message_edit: EventData(type_args=["RawMessageUpdateEvent"]), + Event.raw_message_delete: EventData(type_args=["RawMessageDeleteEvent"]), + Event.raw_bulk_message_delete: EventData(type_args=["RawBulkMessageDeleteEvent"]), + Event.reaction_add: EventData(type_args=["Reaction", "Union[Member, User]"]), + Event.reaction_remove: EventData(type_args=["Reaction", "Union[Member, User]"]), + Event.reaction_clear: EventData(type_args=["Message", "List[Reaction]"]), + Event.reaction_clear_emoji: EventData(type_args=["Reaction"]), + Event.raw_reaction_add: EventData(type_args=["RawReactionActionEvent"]), + Event.raw_reaction_remove: EventData(type_args=["RawReactionActionEvent"]), + Event.raw_reaction_clear: EventData(type_args=["RawReactionClearEvent"]), + Event.raw_reaction_clear_emoji: EventData(type_args=["RawReactionClearEmojiEvent"]), Event.typing: EventData( - ["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"] - ), - Event.raw_typing: EventData(["RawTypingEvent"]), - Event.command: EventData(["commands.Context"], bot=True), - Event.command_completion: EventData(["commands.Context"], bot=True), - Event.command_error: EventData(["commands.Context", "commands.CommandError"], bot=True), - Event.slash_command: EventData(["ApplicationCommandInteraction"], bot=True), - Event.slash_command_completion: EventData(["ApplicationCommandInteraction"], bot=True), + type_args=["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"] + ), + Event.raw_typing: EventData(type_args=["RawTypingEvent"]), + Event.command: EventData(type_args=["commands.Context"], bot=True), + Event.command_completion: EventData(type_args=["commands.Context"], bot=True), + Event.command_error: EventData( + type_args=["commands.Context", "commands.CommandError"], bot=True + ), + Event.slash_command: EventData(type_args=["ApplicationCommandInteraction"], bot=True), + Event.slash_command_completion: EventData( + type_args=["ApplicationCommandInteraction"], bot=True + ), Event.slash_command_error: EventData( - ["ApplicationCommandInteraction", "commands.CommandError"], bot=True + type_args=["ApplicationCommandInteraction", "commands.CommandError"], bot=True ), - Event.user_command: EventData(["ApplicationCommandInteraction"], bot=True), - Event.user_command_completion: EventData(["ApplicationCommandInteraction"], bot=True), + Event.user_command: EventData(type_args=["ApplicationCommandInteraction"], bot=True), + Event.user_command_completion: EventData(type_args=["ApplicationCommandInteraction"], bot=True), Event.user_command_error: EventData( - ["ApplicationCommandInteraction", "commands.CommandError"], bot=True + type_args=["ApplicationCommandInteraction", "commands.CommandError"], bot=True + ), + Event.message_command: EventData(type_args=["ApplicationCommandInteraction"], bot=True), + Event.message_command_completion: EventData( + type_args=["ApplicationCommandInteraction"], bot=True ), - Event.message_command: EventData(["ApplicationCommandInteraction"], bot=True), - Event.message_command_completion: EventData(["ApplicationCommandInteraction"], bot=True), Event.message_command_error: EventData( - ["ApplicationCommandInteraction", "commands.CommandError"], bot=True + type_args=["ApplicationCommandInteraction", "commands.CommandError"], bot=True ), } From 0b5f68bace99df73e07dc1be23bd199f3d74f46c Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 30 Apr 2023 17:52:06 +0200 Subject: [PATCH 11/22] style: put EventData args on separate lines --- disnake/_event_data.py | 418 ++++++++++++++++++++++++++++++----------- 1 file changed, 308 insertions(+), 110 deletions(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index 04fad2e35c..65c4e1cc94 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -30,145 +30,343 @@ def __init__( EVENT_DATA: Dict[Event, EventData] = { - Event.connect: EventData(type_args=[]), - Event.disconnect: EventData(type_args=[]), + Event.connect: EventData( + type_args=[], + ), + Event.disconnect: EventData( + type_args=[], + ), # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event - Event.error: EventData(type_args=[], event_only=True), - Event.gateway_error: EventData(type_args=[], event_only=True), - Event.ready: EventData(type_args=[]), - Event.resumed: EventData(type_args=[]), - Event.shard_connect: EventData(type_args=["int"]), - Event.shard_disconnect: EventData(type_args=["int"]), - Event.shard_ready: EventData(type_args=["int"]), - Event.shard_resumed: EventData(type_args=["int"]), - Event.socket_event_type: EventData(type_args=["str"]), - Event.socket_raw_receive: EventData(type_args=["str"]), - Event.socket_raw_send: EventData(type_args=["Union[str, bytes]"]), - Event.guild_channel_create: EventData(type_args=["GuildChannel"]), - Event.guild_channel_update: EventData(type_args=["GuildChannel", "GuildChannel"]), - Event.guild_channel_delete: EventData(type_args=["GuildChannel"]), + Event.error: EventData( + type_args=[], + event_only=True, + ), + Event.gateway_error: EventData( + type_args=[], + event_only=True, + ), + Event.ready: EventData( + type_args=[], + ), + Event.resumed: EventData( + type_args=[], + ), + Event.shard_connect: EventData( + type_args=["int"], + ), + Event.shard_disconnect: EventData( + type_args=["int"], + ), + Event.shard_ready: EventData( + type_args=["int"], + ), + Event.shard_resumed: EventData( + type_args=["int"], + ), + Event.socket_event_type: EventData( + type_args=["str"], + ), + Event.socket_raw_receive: EventData( + type_args=["str"], + ), + Event.socket_raw_send: EventData( + type_args=["Union[str, bytes]"], + ), + Event.guild_channel_create: EventData( + type_args=["GuildChannel"], + ), + Event.guild_channel_update: EventData( + type_args=["GuildChannel", "GuildChannel"], + ), + Event.guild_channel_delete: EventData( + type_args=["GuildChannel"], + ), Event.guild_channel_pins_update: EventData( - type_args=["Union[GuildChannel, Thread]", "Optional[datetime]"] + type_args=["Union[GuildChannel, Thread]", "Optional[datetime]"], + ), + Event.invite_create: EventData( + type_args=["Invite"], + ), + Event.invite_delete: EventData( + type_args=["Invite"], + ), + Event.private_channel_update: EventData( + type_args=["GroupChannel", "GroupChannel"], ), - Event.invite_create: EventData(type_args=["Invite"]), - Event.invite_delete: EventData(type_args=["Invite"]), - Event.private_channel_update: EventData(type_args=["GroupChannel", "GroupChannel"]), Event.private_channel_pins_update: EventData( - type_args=["PrivateChannel", "Optional[datetime]"] - ), - Event.webhooks_update: EventData(type_args=["GuildChannel"]), - Event.thread_create: EventData(type_args=["Thread"]), - Event.thread_update: EventData(type_args=["Thread", "Thread"]), - Event.thread_delete: EventData(type_args=["Thread"]), - Event.thread_join: EventData(type_args=["Thread"]), - Event.thread_remove: EventData(type_args=["Thread"]), - Event.thread_member_join: EventData(type_args=["ThreadMember"]), - Event.thread_member_remove: EventData(type_args=["ThreadMember"]), - Event.raw_thread_member_remove: EventData(type_args=["RawThreadMemberRemoveEvent"]), - Event.raw_thread_update: EventData(type_args=["Thread"]), - Event.raw_thread_delete: EventData(type_args=["RawThreadDeleteEvent"]), - Event.guild_join: EventData(type_args=["Guild"]), - Event.guild_remove: EventData(type_args=["Guild"]), - Event.guild_update: EventData(type_args=["Guild", "Guild"]), - Event.guild_available: EventData(type_args=["Guild"]), - Event.guild_unavailable: EventData(type_args=["Guild"]), - Event.guild_role_create: EventData(type_args=["Role"]), - Event.guild_role_delete: EventData(type_args=["Role"]), - Event.guild_role_update: EventData(type_args=["Role", "Role"]), - Event.guild_emojis_update: EventData(type_args=["Guild", "Sequence[Emoji]", "Sequence[Emoji]"]), + type_args=["PrivateChannel", "Optional[datetime]"], + ), + Event.webhooks_update: EventData( + type_args=["GuildChannel"], + ), + Event.thread_create: EventData( + type_args=["Thread"], + ), + Event.thread_update: EventData( + type_args=["Thread", "Thread"], + ), + Event.thread_delete: EventData( + type_args=["Thread"], + ), + Event.thread_join: EventData( + type_args=["Thread"], + ), + Event.thread_remove: EventData( + type_args=["Thread"], + ), + Event.thread_member_join: EventData( + type_args=["ThreadMember"], + ), + Event.thread_member_remove: EventData( + type_args=["ThreadMember"], + ), + Event.raw_thread_member_remove: EventData( + type_args=["RawThreadMemberRemoveEvent"], + ), + Event.raw_thread_update: EventData( + type_args=["Thread"], + ), + Event.raw_thread_delete: EventData( + type_args=["RawThreadDeleteEvent"], + ), + Event.guild_join: EventData( + type_args=["Guild"], + ), + Event.guild_remove: EventData( + type_args=["Guild"], + ), + Event.guild_update: EventData( + type_args=["Guild", "Guild"], + ), + Event.guild_available: EventData( + type_args=["Guild"], + ), + Event.guild_unavailable: EventData( + type_args=["Guild"], + ), + Event.guild_role_create: EventData( + type_args=["Role"], + ), + Event.guild_role_delete: EventData( + type_args=["Role"], + ), + Event.guild_role_update: EventData( + type_args=["Role", "Role"], + ), + Event.guild_emojis_update: EventData( + type_args=["Guild", "Sequence[Emoji]", "Sequence[Emoji]"], + ), Event.guild_stickers_update: EventData( - type_args=["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"] + type_args=["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"], + ), + Event.guild_integrations_update: EventData( + type_args=["Guild"], + ), + Event.guild_scheduled_event_create: EventData( + type_args=["GuildScheduledEvent"], ), - Event.guild_integrations_update: EventData(type_args=["Guild"]), - Event.guild_scheduled_event_create: EventData(type_args=["GuildScheduledEvent"]), Event.guild_scheduled_event_update: EventData( - type_args=["GuildScheduledEvent", "GuildScheduledEvent"] + type_args=["GuildScheduledEvent", "GuildScheduledEvent"], + ), + Event.guild_scheduled_event_delete: EventData( + type_args=["GuildScheduledEvent"], ), - Event.guild_scheduled_event_delete: EventData(type_args=["GuildScheduledEvent"]), Event.guild_scheduled_event_subscribe: EventData( - type_args=["GuildScheduledEvent", "Union[Member, User]"] + type_args=["GuildScheduledEvent", "Union[Member, User]"], ), Event.guild_scheduled_event_unsubscribe: EventData( - type_args=["GuildScheduledEvent", "Union[Member, User]"] + type_args=["GuildScheduledEvent", "Union[Member, User]"], ), Event.raw_guild_scheduled_event_subscribe: EventData( - type_args=["RawGuildScheduledEventUserActionEvent"] + type_args=["RawGuildScheduledEventUserActionEvent"], ), Event.raw_guild_scheduled_event_unsubscribe: EventData( - type_args=["RawGuildScheduledEventUserActionEvent"] + type_args=["RawGuildScheduledEventUserActionEvent"], ), Event.application_command_permissions_update: EventData( - type_args=["GuildApplicationCommandPermissions"] - ), - Event.automod_action_execution: EventData(type_args=["AutoModActionExecution"]), - Event.automod_rule_create: EventData(type_args=["AutoModRule"]), - Event.automod_rule_update: EventData(type_args=["AutoModRule"]), - Event.automod_rule_delete: EventData(type_args=["AutoModRule"]), - Event.audit_log_entry_create: EventData(type_args=["AuditLogEntry"]), - Event.integration_create: EventData(type_args=["Integration"]), - Event.integration_update: EventData(type_args=["Integration"]), - Event.raw_integration_delete: EventData(type_args=["RawIntegrationDeleteEvent"]), - Event.member_join: EventData(type_args=["Member"]), - Event.member_remove: EventData(type_args=["Member"]), - Event.member_update: EventData(type_args=["Member", "Member"]), - Event.raw_member_remove: EventData(type_args=["RawGuildMemberRemoveEvent"]), - Event.raw_member_update: EventData(type_args=["Member"]), - Event.member_ban: EventData(type_args=["Guild", "Union[User, Member]"]), - Event.member_unban: EventData(type_args=["Guild", "User"]), - Event.presence_update: EventData(type_args=["Member", "Member"]), - Event.user_update: EventData(type_args=["User", "User"]), - Event.voice_state_update: EventData(type_args=["Member", "VoiceState", "VoiceState"]), - Event.stage_instance_create: EventData(type_args=["StageInstance"]), - Event.stage_instance_delete: EventData(type_args=["StageInstance", "StageInstance"]), - Event.stage_instance_update: EventData(type_args=["StageInstance"]), - Event.application_command: EventData(type_args=["ApplicationCommandInteraction"]), - Event.application_command_autocomplete: EventData(type_args=["ApplicationCommandInteraction"]), - Event.button_click: EventData(type_args=["MessageInteraction"]), - Event.dropdown: EventData(type_args=["MessageInteraction"]), - Event.interaction: EventData(type_args=["Interaction"]), - Event.message_interaction: EventData(type_args=["MessageInteraction"]), - Event.modal_submit: EventData(type_args=["ModalInteraction"]), - Event.message: EventData(type_args=["Message"]), - Event.message_edit: EventData(type_args=["Message", "Message"]), - Event.message_delete: EventData(type_args=["Message"]), - Event.bulk_message_delete: EventData(type_args=["List[Message]"]), - Event.raw_message_edit: EventData(type_args=["RawMessageUpdateEvent"]), - Event.raw_message_delete: EventData(type_args=["RawMessageDeleteEvent"]), - Event.raw_bulk_message_delete: EventData(type_args=["RawBulkMessageDeleteEvent"]), - Event.reaction_add: EventData(type_args=["Reaction", "Union[Member, User]"]), - Event.reaction_remove: EventData(type_args=["Reaction", "Union[Member, User]"]), - Event.reaction_clear: EventData(type_args=["Message", "List[Reaction]"]), - Event.reaction_clear_emoji: EventData(type_args=["Reaction"]), - Event.raw_reaction_add: EventData(type_args=["RawReactionActionEvent"]), - Event.raw_reaction_remove: EventData(type_args=["RawReactionActionEvent"]), - Event.raw_reaction_clear: EventData(type_args=["RawReactionClearEvent"]), - Event.raw_reaction_clear_emoji: EventData(type_args=["RawReactionClearEmojiEvent"]), + type_args=["GuildApplicationCommandPermissions"], + ), + Event.automod_action_execution: EventData( + type_args=["AutoModActionExecution"], + ), + Event.automod_rule_create: EventData( + type_args=["AutoModRule"], + ), + Event.automod_rule_update: EventData( + type_args=["AutoModRule"], + ), + Event.automod_rule_delete: EventData( + type_args=["AutoModRule"], + ), + Event.audit_log_entry_create: EventData( + type_args=["AuditLogEntry"], + ), + Event.integration_create: EventData( + type_args=["Integration"], + ), + Event.integration_update: EventData( + type_args=["Integration"], + ), + Event.raw_integration_delete: EventData( + type_args=["RawIntegrationDeleteEvent"], + ), + Event.member_join: EventData( + type_args=["Member"], + ), + Event.member_remove: EventData( + type_args=["Member"], + ), + Event.member_update: EventData( + type_args=["Member", "Member"], + ), + Event.raw_member_remove: EventData( + type_args=["RawGuildMemberRemoveEvent"], + ), + Event.raw_member_update: EventData( + type_args=["Member"], + ), + Event.member_ban: EventData( + type_args=["Guild", "Union[User, Member]"], + ), + Event.member_unban: EventData( + type_args=["Guild", "User"], + ), + Event.presence_update: EventData( + type_args=["Member", "Member"], + ), + Event.user_update: EventData( + type_args=["User", "User"], + ), + Event.voice_state_update: EventData( + type_args=["Member", "VoiceState", "VoiceState"], + ), + Event.stage_instance_create: EventData( + type_args=["StageInstance"], + ), + Event.stage_instance_delete: EventData( + type_args=["StageInstance", "StageInstance"], + ), + Event.stage_instance_update: EventData( + type_args=["StageInstance"], + ), + Event.application_command: EventData( + type_args=["ApplicationCommandInteraction"], + ), + Event.application_command_autocomplete: EventData( + type_args=["ApplicationCommandInteraction"], + ), + Event.button_click: EventData( + type_args=["MessageInteraction"], + ), + Event.dropdown: EventData( + type_args=["MessageInteraction"], + ), + Event.interaction: EventData( + type_args=["Interaction"], + ), + Event.message_interaction: EventData( + type_args=["MessageInteraction"], + ), + Event.modal_submit: EventData( + type_args=["ModalInteraction"], + ), + Event.message: EventData( + type_args=["Message"], + ), + Event.message_edit: EventData( + type_args=["Message", "Message"], + ), + Event.message_delete: EventData( + type_args=["Message"], + ), + Event.bulk_message_delete: EventData( + type_args=["List[Message]"], + ), + Event.raw_message_edit: EventData( + type_args=["RawMessageUpdateEvent"], + ), + Event.raw_message_delete: EventData( + type_args=["RawMessageDeleteEvent"], + ), + Event.raw_bulk_message_delete: EventData( + type_args=["RawBulkMessageDeleteEvent"], + ), + Event.reaction_add: EventData( + type_args=["Reaction", "Union[Member, User]"], + ), + Event.reaction_remove: EventData( + type_args=["Reaction", "Union[Member, User]"], + ), + Event.reaction_clear: EventData( + type_args=["Message", "List[Reaction]"], + ), + Event.reaction_clear_emoji: EventData( + type_args=["Reaction"], + ), + Event.raw_reaction_add: EventData( + type_args=["RawReactionActionEvent"], + ), + Event.raw_reaction_remove: EventData( + type_args=["RawReactionActionEvent"], + ), + Event.raw_reaction_clear: EventData( + type_args=["RawReactionClearEvent"], + ), + Event.raw_reaction_clear_emoji: EventData( + type_args=["RawReactionClearEmojiEvent"], + ), Event.typing: EventData( - type_args=["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"] + type_args=["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"], + ), + Event.raw_typing: EventData( + type_args=["RawTypingEvent"], + ), + Event.command: EventData( + type_args=["commands.Context"], + bot=True, + ), + Event.command_completion: EventData( + type_args=["commands.Context"], + bot=True, ), - Event.raw_typing: EventData(type_args=["RawTypingEvent"]), - Event.command: EventData(type_args=["commands.Context"], bot=True), - Event.command_completion: EventData(type_args=["commands.Context"], bot=True), Event.command_error: EventData( - type_args=["commands.Context", "commands.CommandError"], bot=True + type_args=["commands.Context", "commands.CommandError"], + bot=True, + ), + Event.slash_command: EventData( + type_args=["ApplicationCommandInteraction"], + bot=True, ), - Event.slash_command: EventData(type_args=["ApplicationCommandInteraction"], bot=True), Event.slash_command_completion: EventData( - type_args=["ApplicationCommandInteraction"], bot=True + type_args=["ApplicationCommandInteraction"], + bot=True, ), Event.slash_command_error: EventData( - type_args=["ApplicationCommandInteraction", "commands.CommandError"], bot=True + type_args=["ApplicationCommandInteraction", "commands.CommandError"], + bot=True, + ), + Event.user_command: EventData( + type_args=["ApplicationCommandInteraction"], + bot=True, + ), + Event.user_command_completion: EventData( + type_args=["ApplicationCommandInteraction"], + bot=True, ), - Event.user_command: EventData(type_args=["ApplicationCommandInteraction"], bot=True), - Event.user_command_completion: EventData(type_args=["ApplicationCommandInteraction"], bot=True), Event.user_command_error: EventData( - type_args=["ApplicationCommandInteraction", "commands.CommandError"], bot=True + type_args=["ApplicationCommandInteraction", "commands.CommandError"], + bot=True, + ), + Event.message_command: EventData( + type_args=["ApplicationCommandInteraction"], + bot=True, ), - Event.message_command: EventData(type_args=["ApplicationCommandInteraction"], bot=True), Event.message_command_completion: EventData( - type_args=["ApplicationCommandInteraction"], bot=True + type_args=["ApplicationCommandInteraction"], + bot=True, ), Event.message_command_error: EventData( - type_args=["ApplicationCommandInteraction", "commands.CommandError"], bot=True + type_args=["ApplicationCommandInteraction", "commands.CommandError"], + bot=True, ), } From 58757068dae3bf96ebef98ea215c4a36c93e2ff5 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 30 Apr 2023 17:53:19 +0200 Subject: [PATCH 12/22] chore: move annotations into `EventData.__init__` --- disnake/_event_data.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index 65c4e1cc94..0451c240bc 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -8,15 +8,6 @@ class EventData: - type_args: Tuple[str, ...] - """Type names of event arguments, e.g. `("Guild", "User")`""" - - bot: bool - """Whether the event is specific to ext.commands""" - - event_only: bool - """Whether the event can only be used through `@event` and not other listeners""" - def __init__( self, *, @@ -24,9 +15,14 @@ def __init__( bot: bool = False, event_only: bool = False, ) -> None: - self.type_args = tuple(type_args) - self.bot = bot - self.event_only = event_only + self.type_args: Tuple[str, ...] = tuple(type_args) + """Type names of event arguments, e.g. `("Guild", "User")`""" + + self.bot: bool = bot + """Whether the event is specific to ext.commands""" + + self.event_only: bool = event_only + """Whether the event can only be used through `@event` and not other listeners""" EVENT_DATA: Dict[Event, EventData] = { From 61f85fbfebc119b3704538523f4b08283e981cfb Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 13 Jun 2023 15:10:56 +0200 Subject: [PATCH 13/22] docs: add changelog entry --- changelog/1017.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1017.feature.rst diff --git a/changelog/1017.feature.rst b/changelog/1017.feature.rst new file mode 100644 index 0000000000..e0c55f0553 --- /dev/null +++ b/changelog/1017.feature.rst @@ -0,0 +1 @@ +Add typing overloads to :meth:`Client.wait_for` for every :class:`Event` value, allowing for correct typing of the ``check`` parameter and the return value. From 602b315420cf4b7de431877a43cbcd34405207ad Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 13 Jun 2023 15:27:09 +0200 Subject: [PATCH 14/22] chore: rename `type_args` field to `arg_types` --- disnake/_event_data.py | 220 +++++++++++++++---------------- scripts/codemods/typed_events.py | 4 +- 2 files changed, 112 insertions(+), 112 deletions(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index 0451c240bc..b62444e932 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -11,11 +11,11 @@ class EventData: def __init__( self, *, - type_args: List[str], + arg_types: List[str], bot: bool = False, event_only: bool = False, ) -> None: - self.type_args: Tuple[str, ...] = tuple(type_args) + self.arg_types: Tuple[str, ...] = tuple(arg_types) """Type names of event arguments, e.g. `("Guild", "User")`""" self.bot: bool = bot @@ -27,342 +27,342 @@ def __init__( EVENT_DATA: Dict[Event, EventData] = { Event.connect: EventData( - type_args=[], + arg_types=[], ), Event.disconnect: EventData( - type_args=[], + arg_types=[], ), # FIXME: figure out how to specify varargs for these two if we ever add overloads for @event Event.error: EventData( - type_args=[], + arg_types=[], event_only=True, ), Event.gateway_error: EventData( - type_args=[], + arg_types=[], event_only=True, ), Event.ready: EventData( - type_args=[], + arg_types=[], ), Event.resumed: EventData( - type_args=[], + arg_types=[], ), Event.shard_connect: EventData( - type_args=["int"], + arg_types=["int"], ), Event.shard_disconnect: EventData( - type_args=["int"], + arg_types=["int"], ), Event.shard_ready: EventData( - type_args=["int"], + arg_types=["int"], ), Event.shard_resumed: EventData( - type_args=["int"], + arg_types=["int"], ), Event.socket_event_type: EventData( - type_args=["str"], + arg_types=["str"], ), Event.socket_raw_receive: EventData( - type_args=["str"], + arg_types=["str"], ), Event.socket_raw_send: EventData( - type_args=["Union[str, bytes]"], + arg_types=["Union[str, bytes]"], ), Event.guild_channel_create: EventData( - type_args=["GuildChannel"], + arg_types=["GuildChannel"], ), Event.guild_channel_update: EventData( - type_args=["GuildChannel", "GuildChannel"], + arg_types=["GuildChannel", "GuildChannel"], ), Event.guild_channel_delete: EventData( - type_args=["GuildChannel"], + arg_types=["GuildChannel"], ), Event.guild_channel_pins_update: EventData( - type_args=["Union[GuildChannel, Thread]", "Optional[datetime]"], + arg_types=["Union[GuildChannel, Thread]", "Optional[datetime]"], ), Event.invite_create: EventData( - type_args=["Invite"], + arg_types=["Invite"], ), Event.invite_delete: EventData( - type_args=["Invite"], + arg_types=["Invite"], ), Event.private_channel_update: EventData( - type_args=["GroupChannel", "GroupChannel"], + arg_types=["GroupChannel", "GroupChannel"], ), Event.private_channel_pins_update: EventData( - type_args=["PrivateChannel", "Optional[datetime]"], + arg_types=["PrivateChannel", "Optional[datetime]"], ), Event.webhooks_update: EventData( - type_args=["GuildChannel"], + arg_types=["GuildChannel"], ), Event.thread_create: EventData( - type_args=["Thread"], + arg_types=["Thread"], ), Event.thread_update: EventData( - type_args=["Thread", "Thread"], + arg_types=["Thread", "Thread"], ), Event.thread_delete: EventData( - type_args=["Thread"], + arg_types=["Thread"], ), Event.thread_join: EventData( - type_args=["Thread"], + arg_types=["Thread"], ), Event.thread_remove: EventData( - type_args=["Thread"], + arg_types=["Thread"], ), Event.thread_member_join: EventData( - type_args=["ThreadMember"], + arg_types=["ThreadMember"], ), Event.thread_member_remove: EventData( - type_args=["ThreadMember"], + arg_types=["ThreadMember"], ), Event.raw_thread_member_remove: EventData( - type_args=["RawThreadMemberRemoveEvent"], + arg_types=["RawThreadMemberRemoveEvent"], ), Event.raw_thread_update: EventData( - type_args=["Thread"], + arg_types=["Thread"], ), Event.raw_thread_delete: EventData( - type_args=["RawThreadDeleteEvent"], + arg_types=["RawThreadDeleteEvent"], ), Event.guild_join: EventData( - type_args=["Guild"], + arg_types=["Guild"], ), Event.guild_remove: EventData( - type_args=["Guild"], + arg_types=["Guild"], ), Event.guild_update: EventData( - type_args=["Guild", "Guild"], + arg_types=["Guild", "Guild"], ), Event.guild_available: EventData( - type_args=["Guild"], + arg_types=["Guild"], ), Event.guild_unavailable: EventData( - type_args=["Guild"], + arg_types=["Guild"], ), Event.guild_role_create: EventData( - type_args=["Role"], + arg_types=["Role"], ), Event.guild_role_delete: EventData( - type_args=["Role"], + arg_types=["Role"], ), Event.guild_role_update: EventData( - type_args=["Role", "Role"], + arg_types=["Role", "Role"], ), Event.guild_emojis_update: EventData( - type_args=["Guild", "Sequence[Emoji]", "Sequence[Emoji]"], + arg_types=["Guild", "Sequence[Emoji]", "Sequence[Emoji]"], ), Event.guild_stickers_update: EventData( - type_args=["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"], + arg_types=["Guild", "Sequence[GuildSticker]", "Sequence[GuildSticker]"], ), Event.guild_integrations_update: EventData( - type_args=["Guild"], + arg_types=["Guild"], ), Event.guild_scheduled_event_create: EventData( - type_args=["GuildScheduledEvent"], + arg_types=["GuildScheduledEvent"], ), Event.guild_scheduled_event_update: EventData( - type_args=["GuildScheduledEvent", "GuildScheduledEvent"], + arg_types=["GuildScheduledEvent", "GuildScheduledEvent"], ), Event.guild_scheduled_event_delete: EventData( - type_args=["GuildScheduledEvent"], + arg_types=["GuildScheduledEvent"], ), Event.guild_scheduled_event_subscribe: EventData( - type_args=["GuildScheduledEvent", "Union[Member, User]"], + arg_types=["GuildScheduledEvent", "Union[Member, User]"], ), Event.guild_scheduled_event_unsubscribe: EventData( - type_args=["GuildScheduledEvent", "Union[Member, User]"], + arg_types=["GuildScheduledEvent", "Union[Member, User]"], ), Event.raw_guild_scheduled_event_subscribe: EventData( - type_args=["RawGuildScheduledEventUserActionEvent"], + arg_types=["RawGuildScheduledEventUserActionEvent"], ), Event.raw_guild_scheduled_event_unsubscribe: EventData( - type_args=["RawGuildScheduledEventUserActionEvent"], + arg_types=["RawGuildScheduledEventUserActionEvent"], ), Event.application_command_permissions_update: EventData( - type_args=["GuildApplicationCommandPermissions"], + arg_types=["GuildApplicationCommandPermissions"], ), Event.automod_action_execution: EventData( - type_args=["AutoModActionExecution"], + arg_types=["AutoModActionExecution"], ), Event.automod_rule_create: EventData( - type_args=["AutoModRule"], + arg_types=["AutoModRule"], ), Event.automod_rule_update: EventData( - type_args=["AutoModRule"], + arg_types=["AutoModRule"], ), Event.automod_rule_delete: EventData( - type_args=["AutoModRule"], + arg_types=["AutoModRule"], ), Event.audit_log_entry_create: EventData( - type_args=["AuditLogEntry"], + arg_types=["AuditLogEntry"], ), Event.integration_create: EventData( - type_args=["Integration"], + arg_types=["Integration"], ), Event.integration_update: EventData( - type_args=["Integration"], + arg_types=["Integration"], ), Event.raw_integration_delete: EventData( - type_args=["RawIntegrationDeleteEvent"], + arg_types=["RawIntegrationDeleteEvent"], ), Event.member_join: EventData( - type_args=["Member"], + arg_types=["Member"], ), Event.member_remove: EventData( - type_args=["Member"], + arg_types=["Member"], ), Event.member_update: EventData( - type_args=["Member", "Member"], + arg_types=["Member", "Member"], ), Event.raw_member_remove: EventData( - type_args=["RawGuildMemberRemoveEvent"], + arg_types=["RawGuildMemberRemoveEvent"], ), Event.raw_member_update: EventData( - type_args=["Member"], + arg_types=["Member"], ), Event.member_ban: EventData( - type_args=["Guild", "Union[User, Member]"], + arg_types=["Guild", "Union[User, Member]"], ), Event.member_unban: EventData( - type_args=["Guild", "User"], + arg_types=["Guild", "User"], ), Event.presence_update: EventData( - type_args=["Member", "Member"], + arg_types=["Member", "Member"], ), Event.user_update: EventData( - type_args=["User", "User"], + arg_types=["User", "User"], ), Event.voice_state_update: EventData( - type_args=["Member", "VoiceState", "VoiceState"], + arg_types=["Member", "VoiceState", "VoiceState"], ), Event.stage_instance_create: EventData( - type_args=["StageInstance"], + arg_types=["StageInstance"], ), Event.stage_instance_delete: EventData( - type_args=["StageInstance", "StageInstance"], + arg_types=["StageInstance", "StageInstance"], ), Event.stage_instance_update: EventData( - type_args=["StageInstance"], + arg_types=["StageInstance"], ), Event.application_command: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], ), Event.application_command_autocomplete: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], ), Event.button_click: EventData( - type_args=["MessageInteraction"], + arg_types=["MessageInteraction"], ), Event.dropdown: EventData( - type_args=["MessageInteraction"], + arg_types=["MessageInteraction"], ), Event.interaction: EventData( - type_args=["Interaction"], + arg_types=["Interaction"], ), Event.message_interaction: EventData( - type_args=["MessageInteraction"], + arg_types=["MessageInteraction"], ), Event.modal_submit: EventData( - type_args=["ModalInteraction"], + arg_types=["ModalInteraction"], ), Event.message: EventData( - type_args=["Message"], + arg_types=["Message"], ), Event.message_edit: EventData( - type_args=["Message", "Message"], + arg_types=["Message", "Message"], ), Event.message_delete: EventData( - type_args=["Message"], + arg_types=["Message"], ), Event.bulk_message_delete: EventData( - type_args=["List[Message]"], + arg_types=["List[Message]"], ), Event.raw_message_edit: EventData( - type_args=["RawMessageUpdateEvent"], + arg_types=["RawMessageUpdateEvent"], ), Event.raw_message_delete: EventData( - type_args=["RawMessageDeleteEvent"], + arg_types=["RawMessageDeleteEvent"], ), Event.raw_bulk_message_delete: EventData( - type_args=["RawBulkMessageDeleteEvent"], + arg_types=["RawBulkMessageDeleteEvent"], ), Event.reaction_add: EventData( - type_args=["Reaction", "Union[Member, User]"], + arg_types=["Reaction", "Union[Member, User]"], ), Event.reaction_remove: EventData( - type_args=["Reaction", "Union[Member, User]"], + arg_types=["Reaction", "Union[Member, User]"], ), Event.reaction_clear: EventData( - type_args=["Message", "List[Reaction]"], + arg_types=["Message", "List[Reaction]"], ), Event.reaction_clear_emoji: EventData( - type_args=["Reaction"], + arg_types=["Reaction"], ), Event.raw_reaction_add: EventData( - type_args=["RawReactionActionEvent"], + arg_types=["RawReactionActionEvent"], ), Event.raw_reaction_remove: EventData( - type_args=["RawReactionActionEvent"], + arg_types=["RawReactionActionEvent"], ), Event.raw_reaction_clear: EventData( - type_args=["RawReactionClearEvent"], + arg_types=["RawReactionClearEvent"], ), Event.raw_reaction_clear_emoji: EventData( - type_args=["RawReactionClearEmojiEvent"], + arg_types=["RawReactionClearEmojiEvent"], ), Event.typing: EventData( - type_args=["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"], + arg_types=["Union[Messageable, ForumChannel]", "Union[User, Member]", "datetime"], ), Event.raw_typing: EventData( - type_args=["RawTypingEvent"], + arg_types=["RawTypingEvent"], ), Event.command: EventData( - type_args=["commands.Context"], + arg_types=["commands.Context"], bot=True, ), Event.command_completion: EventData( - type_args=["commands.Context"], + arg_types=["commands.Context"], bot=True, ), Event.command_error: EventData( - type_args=["commands.Context", "commands.CommandError"], + arg_types=["commands.Context", "commands.CommandError"], bot=True, ), Event.slash_command: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], bot=True, ), Event.slash_command_completion: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], bot=True, ), Event.slash_command_error: EventData( - type_args=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction", "commands.CommandError"], bot=True, ), Event.user_command: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], bot=True, ), Event.user_command_completion: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], bot=True, ), Event.user_command_error: EventData( - type_args=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction", "commands.CommandError"], bot=True, ), Event.message_command: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], bot=True, ), Event.message_command_completion: EventData( - type_args=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction"], bot=True, ), Event.message_command_error: EventData( - type_args=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction", "commands.CommandError"], bot=True, ), } diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 369ef83b45..d617c99090 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -85,14 +85,14 @@ def create_literal(self, event: Event) -> cst.BaseExpression: def create_args_list(self, event_data: EventData) -> cst.BaseExpression: return cst.parse_expression( - f'[{",".join(event_data.type_args)}]', + f'[{",".join(event_data.arg_types)}]', config=self.module.config_for_parsing, ) def generate_wait_for_overload( self, func: cst.FunctionDef, event: Event, event_data: EventData ) -> cst.FunctionDef: - args = event_data.type_args + args = event_data.arg_types new_overload = self.create_empty_overload(func) From f80d199e1742df859cb4f0dd24b4376cacd8ec1d Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 1 Aug 2023 20:52:05 +0200 Subject: [PATCH 15/22] fix: remove duplicate key from pyproject --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e54b23b2c7..f7f2b98b62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,7 +217,6 @@ ignore = [ ] "scripts/*.py" = ["S101"] # use of assert is okay in scripts "tests/*.py" = ["S101"] # use of assert is okay in test files -"scripts/*.py" = ["S101"] # use of assert is okay in codemods # we are not using noqa in the example files themselves "examples/*.py" = [ "B008", # do not perform function calls in argument defaults, this is how most commands work From f92a98425bb792d20bdac9bd04c9602687d86a08 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 1 Aug 2023 22:49:15 +0200 Subject: [PATCH 16/22] fix: add codemod to new combined module --- scripts/codemods/combined.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/codemods/combined.py b/scripts/codemods/combined.py index a3cd957d01..e6b5c90ffa 100644 --- a/scripts/codemods/combined.py +++ b/scripts/codemods/combined.py @@ -5,11 +5,12 @@ import libcst as cst from libcst import codemod -from . import overloads_no_missing, typed_flags, typed_permissions +from . import overloads_no_missing, typed_events, typed_flags, typed_permissions from .base import NoMetadataWrapperMixin CODEMODS = [ overloads_no_missing.EllipsisOverloads, + typed_events.EventTypings, typed_flags.FlagTypings, typed_permissions.PermissionTypings, ] From 600af6038894098ce77058c368035a1ca382bd80 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 21 Sep 2023 00:16:27 +0200 Subject: [PATCH 17/22] fix(typing): add generic bot parameter to interaction types --- disnake/_event_data.py | 32 ++++++------ disnake/client.py | 87 +++++++++++++++++--------------- scripts/codemods/typed_events.py | 3 +- tests/test_events.py | 2 +- 4 files changed, 65 insertions(+), 59 deletions(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index b62444e932..be7946a632 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -246,25 +246,25 @@ def __init__( arg_types=["StageInstance"], ), Event.application_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[Self]"], ), Event.application_command_autocomplete: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[Self]"], ), Event.button_click: EventData( - arg_types=["MessageInteraction"], + arg_types=["MessageInteraction[Self]"], ), Event.dropdown: EventData( - arg_types=["MessageInteraction"], + arg_types=["MessageInteraction[Self]"], ), Event.interaction: EventData( - arg_types=["Interaction"], + arg_types=["Interaction[Self]"], ), Event.message_interaction: EventData( - arg_types=["MessageInteraction"], + arg_types=["MessageInteraction[Self]"], ), Event.modal_submit: EventData( - arg_types=["ModalInteraction"], + arg_types=["ModalInteraction[Self]"], ), Event.message: EventData( arg_types=["Message"], @@ -330,39 +330,39 @@ def __init__( bot=True, ), Event.slash_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.slash_command_completion: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.slash_command_error: EventData( - arg_types=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], bot=True, ), Event.user_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.user_command_completion: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.user_command_error: EventData( - arg_types=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], bot=True, ), Event.message_command: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.message_command_completion: EventData( - arg_types=["ApplicationCommandInteraction"], + arg_types=["ApplicationCommandInteraction[AnyBotT]"], bot=True, ), Event.message_command_error: EventData( - arg_types=["ApplicationCommandInteraction", "commands.CommandError"], + arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], bot=True, ), } diff --git a/disnake/client.py b/disnake/client.py index ba8f49ebd7..63f1c65f69 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -79,6 +79,8 @@ from .widget import Widget if TYPE_CHECKING: + from typing_extensions import Self + from disnake.ext import commands from .abc import GuildChannel, Messageable, PrivateChannel, Snowflake, SnowflakeTime @@ -125,6 +127,9 @@ commands.InteractionBot, commands.AutoShardedInteractionBot, ] + # we can't use `typing.Self` when the `self: AnyBot` parameter is annotated, + # so go back to the old way of using a TypeVar for those overloads + AnyBotT = TypeVar("AnyBotT", bound=AnyBot) __all__ = ( @@ -2500,9 +2505,9 @@ def wait_for( self, event: Literal[Event.application_command, "application_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[Self]]: ... @overload @@ -2511,9 +2516,9 @@ def wait_for( self, event: Literal[Event.application_command_autocomplete, "application_command_autocomplete"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[Self]]: ... @overload @@ -2522,9 +2527,9 @@ def wait_for( self, event: Literal[Event.button_click, "button_click"], *, - check: Optional[Callable[[MessageInteraction], bool]] = None, + check: Optional[Callable[[MessageInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, MessageInteraction]: + ) -> Coroutine[Any, Any, MessageInteraction[Self]]: ... @overload @@ -2533,9 +2538,9 @@ def wait_for( self, event: Literal[Event.dropdown, "dropdown"], *, - check: Optional[Callable[[MessageInteraction], bool]] = None, + check: Optional[Callable[[MessageInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, MessageInteraction]: + ) -> Coroutine[Any, Any, MessageInteraction[Self]]: ... @overload @@ -2544,9 +2549,9 @@ def wait_for( self, event: Literal[Event.interaction, "interaction"], *, - check: Optional[Callable[[Interaction], bool]] = None, + check: Optional[Callable[[Interaction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Interaction]: + ) -> Coroutine[Any, Any, Interaction[Self]]: ... @overload @@ -2555,9 +2560,9 @@ def wait_for( self, event: Literal[Event.message_interaction, "message_interaction"], *, - check: Optional[Callable[[MessageInteraction], bool]] = None, + check: Optional[Callable[[MessageInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, MessageInteraction]: + ) -> Coroutine[Any, Any, MessageInteraction[Self]]: ... @overload @@ -2566,9 +2571,9 @@ def wait_for( self, event: Literal[Event.modal_submit, "modal_submit"], *, - check: Optional[Callable[[ModalInteraction], bool]] = None, + check: Optional[Callable[[ModalInteraction[Self]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ModalInteraction]: + ) -> Coroutine[Any, Any, ModalInteraction[Self]]: ... @overload @@ -2798,106 +2803,106 @@ def wait_for( @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.slash_command, "slash_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.slash_command_completion, "slash_command_completion"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.slash_command_error, "slash_command_error"], *, check: Optional[ - Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool] ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.user_command, "user_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.user_command_completion, "user_command_completion"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.user_command_error, "user_command_error"], *, check: Optional[ - Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool] ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.message_command, "message_command"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.message_command_completion, "message_command_completion"], *, - check: Optional[Callable[[ApplicationCommandInteraction], bool]] = None, + check: Optional[Callable[[ApplicationCommandInteraction[AnyBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, ApplicationCommandInteraction]: + ) -> Coroutine[Any, Any, ApplicationCommandInteraction[AnyBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyBotT, event: Literal[Event.message_command_error, "message_command_error"], *, check: Optional[ - Callable[[ApplicationCommandInteraction, commands.CommandError], bool] + Callable[[ApplicationCommandInteraction[AnyBotT], commands.CommandError], bool] ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[ApplicationCommandInteraction[AnyBotT], commands.CommandError]]: ... # fallback for custom events diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index d617c99090..68a7b4908c 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -127,9 +127,10 @@ def generate_wait_for_overload( # set `self` annotation as a workaround for overloads in subclasses if event_data.bot: + self_type = "AnyBotT" if "AnyBotT" in new_annotation_str else "AnyBot" # fun. new_overload = new_overload.with_deep_changes( get_param(new_overload, "self"), - annotation=cst.Annotation(cst.Name("AnyBot")), + annotation=cst.Annotation(cst.Name(self_type)), ) return new_overload diff --git a/tests/test_events.py b/tests/test_events.py index a25f7dd771..fa4952195e 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -72,7 +72,7 @@ def _test_typing_wait_for(client: disnake.Client, bot: commands.Bot) -> None: _ = client.wait_for(Event.slash_command_error) # type: ignore # this should error _ = assert_type( bot.wait_for(Event.slash_command), - Coroutine[Any, Any, disnake.ApplicationCommandInteraction], + Coroutine[Any, Any, disnake.ApplicationCommandInteraction[commands.Bot]], ) From 496902a77483382b1478c0a8beafec78d5b3cdcf Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 21 Sep 2023 00:36:06 +0200 Subject: [PATCH 18/22] feat(typing): add generic bot parameter to `Context` events --- disnake/_event_data.py | 43 ++++++++++++++++---------------- disnake/client.py | 25 +++++++++++-------- disnake/interactions/base.py | 3 --- scripts/codemods/typed_events.py | 5 ++-- 4 files changed, 38 insertions(+), 38 deletions(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index be7946a632..b25c46f3de 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from .enums import Event @@ -12,17 +12,17 @@ def __init__( self, *, arg_types: List[str], - bot: bool = False, + self_type: Optional[str] = None, event_only: bool = False, ) -> None: self.arg_types: Tuple[str, ...] = tuple(arg_types) - """Type names of event arguments, e.g. `("Guild", "User")`""" + """Type names of event arguments, e.g. `("Guild", "User")`.""" - self.bot: bool = bot - """Whether the event is specific to ext.commands""" + self.self_type: Optional[str] = self_type + """The annotation for the `self` parameter, used for bot-only events.""" self.event_only: bool = event_only - """Whether the event can only be used through `@event` and not other listeners""" + """Whether the event can only be used through `@event`, and not with listeners.""" EVENT_DATA: Dict[Event, EventData] = { @@ -317,52 +317,53 @@ def __init__( Event.raw_typing: EventData( arg_types=["RawTypingEvent"], ), + # bot-only: Event.command: EventData( - arg_types=["commands.Context"], - bot=True, + arg_types=["commands.Context[AnyPrefixBotT]"], + self_type="AnyPrefixBotT", ), Event.command_completion: EventData( - arg_types=["commands.Context"], - bot=True, + arg_types=["commands.Context[AnyPrefixBotT]"], + self_type="AnyPrefixBotT", ), Event.command_error: EventData( - arg_types=["commands.Context", "commands.CommandError"], - bot=True, + arg_types=["commands.Context[AnyPrefixBotT]", "commands.CommandError"], + self_type="AnyPrefixBotT", ), Event.slash_command: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.slash_command_completion: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.slash_command_error: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], - bot=True, + self_type="AnyBotT", ), Event.user_command: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.user_command_completion: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.user_command_error: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], - bot=True, + self_type="AnyBotT", ), Event.message_command: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.message_command_completion: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]"], - bot=True, + self_type="AnyBotT", ), Event.message_command_error: EventData( arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"], - bot=True, + self_type="AnyBotT", ), } diff --git a/disnake/client.py b/disnake/client.py index 63f1c65f69..416fb69da1 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -121,14 +121,15 @@ from .types.gateway import SessionStartLimit as SessionStartLimitPayload from .voice_client import VoiceProtocol + AnyPrefixBot = Union[commands.Bot, commands.AutoShardedBot] AnyBot = Union[ - commands.Bot, - commands.AutoShardedBot, + AnyPrefixBot, commands.InteractionBot, commands.AutoShardedInteractionBot, ] # we can't use `typing.Self` when the `self: AnyBot` parameter is annotated, # so go back to the old way of using a TypeVar for those overloads + AnyPrefixBotT = TypeVar("AnyPrefixBotT", bound=AnyPrefixBot) AnyBotT = TypeVar("AnyBotT", bound=AnyBot) @@ -2770,34 +2771,36 @@ def wait_for( @overload @_generated def wait_for( - self: AnyBot, + self: AnyPrefixBotT, event: Literal[Event.command, "command"], *, - check: Optional[Callable[[commands.Context], bool]] = None, + check: Optional[Callable[[commands.Context[AnyPrefixBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, commands.Context]: + ) -> Coroutine[Any, Any, commands.Context[AnyPrefixBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyPrefixBotT, event: Literal[Event.command_completion, "command_completion"], *, - check: Optional[Callable[[commands.Context], bool]] = None, + check: Optional[Callable[[commands.Context[AnyPrefixBotT]], bool]] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, commands.Context]: + ) -> Coroutine[Any, Any, commands.Context[AnyPrefixBotT]]: ... @overload @_generated def wait_for( - self: AnyBot, + self: AnyPrefixBotT, event: Literal[Event.command_error, "command_error"], *, - check: Optional[Callable[[commands.Context, commands.CommandError], bool]] = None, + check: Optional[ + Callable[[commands.Context[AnyPrefixBotT], commands.CommandError], bool] + ] = None, timeout: Optional[float] = None, - ) -> Coroutine[Any, Any, Tuple[commands.Context, commands.CommandError]]: + ) -> Coroutine[Any, Any, Tuple[commands.Context[AnyPrefixBotT], commands.CommandError]]: ... @overload diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index bdcbe3cae2..d40b54ce52 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -69,7 +69,6 @@ from ..app_commands import Choices from ..client import Client from ..embeds import Embed - from ..ext.commands import AutoShardedBot, Bot from ..file import File from ..guild import GuildChannel, GuildMessageable from ..mentions import AllowedMentions @@ -90,8 +89,6 @@ InteractionChannel = Union[GuildChannel, Thread, PartialMessageable] - AnyBot = Union[Bot, AutoShardedBot] - MISSING: Any = utils.MISSING diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 68a7b4908c..9bbdb2fcea 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -126,11 +126,10 @@ def generate_wait_for_overload( new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation)) # set `self` annotation as a workaround for overloads in subclasses - if event_data.bot: - self_type = "AnyBotT" if "AnyBotT" in new_annotation_str else "AnyBot" # fun. + if event_data.self_type: new_overload = new_overload.with_deep_changes( get_param(new_overload, "self"), - annotation=cst.Annotation(cst.Name(self_type)), + annotation=cst.Annotation(cst.Name(event_data.self_type)), ) return new_overload From 5bba9e3329526640424e689235b2c7c1cf2aa2a7 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 2 Jan 2024 16:50:25 +0100 Subject: [PATCH 19/22] chore: remove unnecessary type-ignore --- test_bot/cogs/modals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_bot/cogs/modals.py b/test_bot/cogs/modals.py index c5d514a25c..13c84bddf2 100644 --- a/test_bot/cogs/modals.py +++ b/test_bot/cogs/modals.py @@ -65,7 +65,7 @@ async def create_tag_low(self, inter: disnake.AppCmdInter[commands.Bot]) -> None modal_inter: disnake.ModalInteraction = await self.bot.wait_for( "modal_submit", - check=lambda i: i.custom_id == "create_tag2" and i.author.id == inter.author.id, # type: ignore # unknown parameter type + check=lambda i: i.custom_id == "create_tag2" and i.author.id == inter.author.id, ) embed = disnake.Embed(title="Tag Creation") From f0f44e9aa56b956247e135626afa63bf422c9124 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sat, 24 Feb 2024 17:48:24 +0100 Subject: [PATCH 20/22] feat: add new entitlement/presence events --- disnake/_event_data.py | 12 +++++++++++ disnake/client.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index b25c46f3de..9f0765d647 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -299,6 +299,9 @@ def __init__( Event.reaction_clear_emoji: EventData( arg_types=["Reaction"], ), + Event.raw_presence_update: EventData( + arg_types=["RawPresenceUpdateEvent"], + ), Event.raw_reaction_add: EventData( arg_types=["RawReactionActionEvent"], ), @@ -317,6 +320,15 @@ def __init__( Event.raw_typing: EventData( arg_types=["RawTypingEvent"], ), + Event.entitlement_create: EventData( + arg_types=["Entitlement"], + ), + Event.entitlement_update: EventData( + arg_types=["Entitlement"], + ), + Event.entitlement_delete: EventData( + arg_types=["Entitlement"], + ), # bot-only: Event.command: EventData( arg_types=["commands.Context[AnyPrefixBotT]"], diff --git a/disnake/client.py b/disnake/client.py index 5e60998970..ec282f0540 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -109,6 +109,7 @@ RawIntegrationDeleteEvent, RawMessageDeleteEvent, RawMessageUpdateEvent, + RawPresenceUpdateEvent, RawReactionActionEvent, RawReactionClearEmojiEvent, RawReactionClearEvent, @@ -2715,6 +2716,17 @@ def wait_for( ) -> Coroutine[Any, Any, Reaction]: ... + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_presence_update, "raw_presence_update"], + *, + check: Optional[Callable[[RawPresenceUpdateEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawPresenceUpdateEvent]: + ... + @overload @_generated def wait_for( @@ -2785,6 +2797,39 @@ def wait_for( ) -> Coroutine[Any, Any, RawTypingEvent]: ... + @overload + @_generated + def wait_for( + self, + event: Literal[Event.entitlement_create, "entitlement_create"], + *, + check: Optional[Callable[[Entitlement], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Entitlement]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.entitlement_update, "entitlement_update"], + *, + check: Optional[Callable[[Entitlement], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Entitlement]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.entitlement_delete, "entitlement_delete"], + *, + check: Optional[Callable[[Entitlement], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Entitlement]: + ... + @overload @_generated def wait_for( From 251f8ec133a7145037fbfd18af3fde71f3e66904 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sat, 24 Feb 2024 17:57:39 +0100 Subject: [PATCH 21/22] perf(codemod): use `CHECK_MARKER` instead of `tree.code` --- scripts/codemods/typed_events.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/scripts/codemods/typed_events.py b/scripts/codemods/typed_events.py index 9bbdb2fcea..30aae626c2 100644 --- a/scripts/codemods/typed_events.py +++ b/scripts/codemods/typed_events.py @@ -7,11 +7,12 @@ import libcst as cst import libcst.matchers as m -from libcst import codemod from disnake import Event from disnake._event_data import EVENT_DATA, EventData +from .base import BaseCodemodCommand + def get_param(func: cst.FunctionDef, name: str) -> cst.Param: results = m.findall(func.params, m.Param(m.Name(name))) @@ -19,19 +20,13 @@ def get_param(func: cst.FunctionDef, name: str) -> cst.Param: return cast(cst.Param, results[0]) -class EventTypings(codemod.VisitorBasedCodemodCommand): +class EventTypings(BaseCodemodCommand): DESCRIPTION: str = "Adds overloads for library events." + CHECK_MARKER: str = "@_overload_with_events" flag_classes: List[str] imported_module: types.ModuleType - def transform_module(self, tree: cst.Module) -> cst.Module: - if "@_overload_with_events" not in tree.code: - raise codemod.SkipFile( - "this module does not contain the required decorator: `@_overload_with_events`." - ) - return super().transform_module(tree) - def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # don't recurse into the body of a function return False From f248d30be91a61c29c8df70bdf036af9988c6607 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Wed, 11 Dec 2024 18:46:45 +0100 Subject: [PATCH 22/22] feat: add overloads for new events --- disnake/_event_data.py | 18 +++++++++++ disnake/client.py | 71 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/disnake/_event_data.py b/disnake/_event_data.py index 9f0765d647..685ea725c3 100644 --- a/disnake/_event_data.py +++ b/disnake/_event_data.py @@ -236,6 +236,12 @@ def __init__( Event.voice_state_update: EventData( arg_types=["Member", "VoiceState", "VoiceState"], ), + Event.voice_channel_effect: EventData( + arg_types=["GuildChannel", "Member", "VoiceChannelEffect"], + ), + Event.raw_voice_channel_effect: EventData( + arg_types=["RawVoiceChannelEffectEvent"], + ), Event.stage_instance_create: EventData( arg_types=["StageInstance"], ), @@ -278,6 +284,12 @@ def __init__( Event.bulk_message_delete: EventData( arg_types=["List[Message]"], ), + Event.poll_vote_add: EventData( + arg_types=["Member", "PollAnswer"], + ), + Event.poll_vote_remove: EventData( + arg_types=["Member", "PollAnswer"], + ), Event.raw_message_edit: EventData( arg_types=["RawMessageUpdateEvent"], ), @@ -287,6 +299,12 @@ def __init__( Event.raw_bulk_message_delete: EventData( arg_types=["RawBulkMessageDeleteEvent"], ), + Event.raw_poll_vote_add: EventData( + arg_types=["RawPollVoteActionEvent"], + ), + Event.raw_poll_vote_remove: EventData( + arg_types=["RawPollVoteActionEvent"], + ), Event.reaction_add: EventData( arg_types=["Reaction", "Union[Member, User]"], ), diff --git a/disnake/client.py b/disnake/client.py index 5fc61fa34b..4d53ebcd6f 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -91,7 +91,7 @@ from .asset import AssetBytes from .audit_logs import AuditLogEntry from .automod import AutoModActionExecution, AutoModRule - from .channel import DMChannel, ForumChannel, GroupChannel + from .channel import DMChannel, ForumChannel, GroupChannel, VoiceChannelEffect from .guild_scheduled_event import GuildScheduledEvent from .integrations import Integration from .interactions import ( @@ -102,6 +102,7 @@ ) from .member import Member, VoiceState from .message import Message + from .poll import PollAnswer from .raw_models import ( RawBulkMessageDeleteEvent, RawGuildMemberRemoveEvent, @@ -109,6 +110,7 @@ RawIntegrationDeleteEvent, RawMessageDeleteEvent, RawMessageUpdateEvent, + RawPollVoteActionEvent, RawPresenceUpdateEvent, RawReactionActionEvent, RawReactionClearEmojiEvent, @@ -116,6 +118,7 @@ RawThreadDeleteEvent, RawThreadMemberRemoveEvent, RawTypingEvent, + RawVoiceChannelEffectEvent, ) from .reaction import Reaction from .role import Role @@ -2490,6 +2493,28 @@ def wait_for( ) -> Coroutine[Any, Any, Tuple[Member, VoiceState, VoiceState]]: ... + @overload + @_generated + def wait_for( + self, + event: Literal[Event.voice_channel_effect, "voice_channel_effect"], + *, + check: Optional[Callable[[GuildChannel, Member, VoiceChannelEffect], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[GuildChannel, Member, VoiceChannelEffect]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_voice_channel_effect, "raw_voice_channel_effect"], + *, + check: Optional[Callable[[RawVoiceChannelEffectEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawVoiceChannelEffectEvent]: + ... + @overload @_generated def wait_for( @@ -2644,6 +2669,28 @@ def wait_for( ) -> Coroutine[Any, Any, List[Message]]: ... + @overload + @_generated + def wait_for( + self, + event: Literal[Event.poll_vote_add, "poll_vote_add"], + *, + check: Optional[Callable[[Member, PollAnswer], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Member, PollAnswer]]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.poll_vote_remove, "poll_vote_remove"], + *, + check: Optional[Callable[[Member, PollAnswer], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, Tuple[Member, PollAnswer]]: + ... + @overload @_generated def wait_for( @@ -2677,6 +2724,28 @@ def wait_for( ) -> Coroutine[Any, Any, RawBulkMessageDeleteEvent]: ... + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_poll_vote_add, "raw_poll_vote_add"], + *, + check: Optional[Callable[[RawPollVoteActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawPollVoteActionEvent]: + ... + + @overload + @_generated + def wait_for( + self, + event: Literal[Event.raw_poll_vote_remove, "raw_poll_vote_remove"], + *, + check: Optional[Callable[[RawPollVoteActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> Coroutine[Any, Any, RawPollVoteActionEvent]: + ... + @overload @_generated def wait_for(