Skip to content

Commit

Permalink
feat(typing): add generic bot parameter to Context events
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Sep 20, 2023
1 parent 600af60 commit 496902a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 38 deletions.
43 changes: 22 additions & 21 deletions disnake/_event_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

from .enums import Event

Expand All @@ -12,17 +12,17 @@ def __init__(
self,
*,
arg_types: List[str],
bot: bool = False,
self_type: Optional[str] = None,
event_only: bool = False,
) -> None:
self.arg_types: Tuple[str, ...] = tuple(arg_types)
"""Type names of event arguments, e.g. `("Guild", "User")`"""
"""Type names of event arguments, e.g. `("Guild", "User")`."""

self.bot: bool = bot
"""Whether the event is specific to ext.commands"""
self.self_type: Optional[str] = self_type
"""The annotation for the `self` parameter, used for bot-only events."""

self.event_only: bool = event_only
"""Whether the event can only be used through `@event` and not other listeners"""
"""Whether the event can only be used through `@event`, and not with listeners."""


EVENT_DATA: Dict[Event, EventData] = {
Expand Down Expand Up @@ -317,52 +317,53 @@ def __init__(
Event.raw_typing: EventData(
arg_types=["RawTypingEvent"],
),
# bot-only:
Event.command: EventData(
arg_types=["commands.Context"],
bot=True,
arg_types=["commands.Context[AnyPrefixBotT]"],
self_type="AnyPrefixBotT",
),
Event.command_completion: EventData(
arg_types=["commands.Context"],
bot=True,
arg_types=["commands.Context[AnyPrefixBotT]"],
self_type="AnyPrefixBotT",
),
Event.command_error: EventData(
arg_types=["commands.Context", "commands.CommandError"],
bot=True,
arg_types=["commands.Context[AnyPrefixBotT]", "commands.CommandError"],
self_type="AnyPrefixBotT",
),
Event.slash_command: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
self_type="AnyBotT",
),
Event.slash_command_completion: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
self_type="AnyBotT",
),
Event.slash_command_error: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"],
bot=True,
self_type="AnyBotT",
),
Event.user_command: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
self_type="AnyBotT",
),
Event.user_command_completion: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
self_type="AnyBotT",
),
Event.user_command_error: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"],
bot=True,
self_type="AnyBotT",
),
Event.message_command: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
self_type="AnyBotT",
),
Event.message_command_completion: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]"],
bot=True,
self_type="AnyBotT",
),
Event.message_command_error: EventData(
arg_types=["ApplicationCommandInteraction[AnyBotT]", "commands.CommandError"],
bot=True,
self_type="AnyBotT",
),
}
25 changes: 14 additions & 11 deletions disnake/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,15 @@
from .types.gateway import SessionStartLimit as SessionStartLimitPayload
from .voice_client import VoiceProtocol

AnyPrefixBot = Union[commands.Bot, commands.AutoShardedBot]
AnyBot = Union[
commands.Bot,
commands.AutoShardedBot,
AnyPrefixBot,
commands.InteractionBot,
commands.AutoShardedInteractionBot,
]
# we can't use `typing.Self` when the `self: AnyBot` parameter is annotated,
# so go back to the old way of using a TypeVar for those overloads
AnyPrefixBotT = TypeVar("AnyPrefixBotT", bound=AnyPrefixBot)
AnyBotT = TypeVar("AnyBotT", bound=AnyBot)


Expand Down Expand Up @@ -2770,34 +2771,36 @@ def wait_for(
@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyPrefixBotT,
event: Literal[Event.command, "command"],
*,
check: Optional[Callable[[commands.Context], bool]] = None,
check: Optional[Callable[[commands.Context[AnyPrefixBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, commands.Context]:
) -> Coroutine[Any, Any, commands.Context[AnyPrefixBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyPrefixBotT,
event: Literal[Event.command_completion, "command_completion"],
*,
check: Optional[Callable[[commands.Context], bool]] = None,
check: Optional[Callable[[commands.Context[AnyPrefixBotT]], bool]] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, commands.Context]:
) -> Coroutine[Any, Any, commands.Context[AnyPrefixBotT]]:
...

@overload
@_generated
def wait_for(
self: AnyBot,
self: AnyPrefixBotT,
event: Literal[Event.command_error, "command_error"],
*,
check: Optional[Callable[[commands.Context, commands.CommandError], bool]] = None,
check: Optional[
Callable[[commands.Context[AnyPrefixBotT], commands.CommandError], bool]
] = None,
timeout: Optional[float] = None,
) -> Coroutine[Any, Any, Tuple[commands.Context, commands.CommandError]]:
) -> Coroutine[Any, Any, Tuple[commands.Context[AnyPrefixBotT], commands.CommandError]]:
...

@overload
Expand Down
3 changes: 0 additions & 3 deletions disnake/interactions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
from ..app_commands import Choices
from ..client import Client
from ..embeds import Embed
from ..ext.commands import AutoShardedBot, Bot
from ..file import File
from ..guild import GuildChannel, GuildMessageable
from ..mentions import AllowedMentions
Expand All @@ -90,8 +89,6 @@

InteractionChannel = Union[GuildChannel, Thread, PartialMessageable]

AnyBot = Union[Bot, AutoShardedBot]


MISSING: Any = utils.MISSING

Expand Down
5 changes: 2 additions & 3 deletions scripts/codemods/typed_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,10 @@ def generate_wait_for_overload(
new_overload = new_overload.with_changes(returns=cst.Annotation(new_annotation))

# set `self` annotation as a workaround for overloads in subclasses
if event_data.bot:
self_type = "AnyBotT" if "AnyBotT" in new_annotation_str else "AnyBot" # fun.
if event_data.self_type:
new_overload = new_overload.with_deep_changes(
get_param(new_overload, "self"),
annotation=cst.Annotation(cst.Name(self_type)),
annotation=cst.Annotation(cst.Name(event_data.self_type)),
)

return new_overload

0 comments on commit 496902a

Please sign in to comment.