From 352446bc836994143c59b90f7b8ccc41a6cfc016 Mon Sep 17 00:00:00 2001 From: Middledot Date: Fri, 26 Aug 2022 14:45:34 -0400 Subject: [PATCH 01/54] Initial Implementation of Invokables ??? --- discord/abc.py | 401 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 401 insertions(+) diff --git a/discord/abc.py b/discord/abc.py index 00a6a13c10..445fa23d23 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1850,3 +1850,404 @@ async def connect( class Mentionable: # TODO: documentation, methods if needed pass + +import functools +from typing import Coroutine, TypeVar +from .utils import async_all +from .errors import ApplicationCommandError, ApplicationCommandInvokeError +from .ext.commands import ( + DisabledCommand, + CheckFailure, + CommandOnCooldown, + CooldownMapping, + BucketType, + MaxConcurrency, + CommandError +) + + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + from .cog import Cog + + P = ParamSpec("P") +else: + P = TypeVar("P") + +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, "__wrapped__"): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + +def hooked_wrapped_callback(command: Invokable, ctx: ContextT, coro: CallbackT): + @functools.wraps(coro) + async def wrapped(arg): + try: + ret = await coro(arg) + except (ApplicationCommandError, CommandError): + raise + except asyncio.CancelledError: + return + except Exception as exc: + raise ApplicationCommandInvokeError(exc) from exc + finally: + if command._max_concurrency is not None: + await command._max_concurrency.release(ctx) + await command.call_after_hooks(ctx) + + return ret + + return wrapped + + +CallbackT = TypeVar("CallbackT") +ErrorT = TypeVar("ErrorT") +HookT = TypeVar("HookT") +ContextT = TypeVar("ContextT") +MaybeCoro = Union[Any, Coroutine[Any]] +Check = Union[ + Callable[[ContextT], MaybeCoro[bool]], + Callable[[Cog, ContextT], MaybeCoro[bool]] +] + +class Invokable: + _callback: CallbackT + cog: Optional[Cog] + parent: Optional[Invokable] + module: Any + + name: str + enabled: bool + + checks: List[Check] + _buckets: CooldownMapping + _max_concurrency: Optional[MaxConcurrency] + on_error: Optional[ErrorT] + _before_invoke: Optional[HookT] + _after_invoke: Optional[HookT] + + def __init__(self, func: CallbackT, **kwargs): + self.callback = func + self.cog = None + self.module = None + self.enabled = kwargs.get("enabled", True) + + # checks + if checks := getattr(func, "__commands_checks__", []): + checks.reverse() + + checks += kwargs.get("checks", []) # combine all the checks we find (kwargs or decorator) + self.checks = checks + + # cooldowns + cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) + + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets = cooldown + else: + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + + self._buckets = buckets + + # max concurrency + self._max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) + + # hooks + self._before_invoke = None + if hook := getattr(func, "__before_invoke__", None): + self.before_invoke(hook) + + self._after_invoke = None + if hook := getattr(func, "__after_invoke__", None): + self.after_invoke(hook) + + @property + def callback(self) -> CallbackT: + return self._callback + + @callback.setter + def callback(self, func: CallbackT) -> None: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Callback must be a coroutine.") + + self._callback = func + unwrap = unwrap_function(func) + self.module = unwrap.__module__ + + async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): + """|coro| + + Calls the internal callback that the command holds. + + .. note:: + + This bypasses all mechanisms -- including checks, converters, + invoke hooks, cooldowns, etc. You must take care to pass + the proper arguments and types to this function. + + """ + if self.cog is not None: + return await self.callback(self.cog, ctx, *args, **kwargs) + return await self.callback(ctx, *args, **kwargs) + + def error(self, coro: ErrorT) -> ErrorT: + """A decorator that registers a coroutine as a local error handler. + + A local error handler is an :func:`.on_command_error` event limited to + a single command. However, the :func:`.on_command_error` is still + invoked afterwards as the catch-all. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The error handler must be a coroutine.") + + self.on_error = coro + return coro + + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the command has an error handler registered.""" + return hasattr(self, "on_error") + + def before_invoke(self, coro: HookT) -> HookT: + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.before_invoke` for more info. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the pre-invoke hook. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro: HookT) -> HookT: + """A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.after_invoke` for more info. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the post-invoke hook. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + # might require more of a library wide refactoring + + # async def can_run(self, ctx: ContextT) -> bool: + # """|coro| + + # Checks if the command can be executed by checking all the predicates + # inside the :attr:`~Command.checks` attribute. This also checks whether the + # command is disabled. + + # .. versionchanged:: 1.3 + # Checks whether the command is disabled or not + + # Parameters + # ----------- + # ctx: :class:`.Context` + # The ctx of the command currently being invoked. + + # Raises + # ------- + # :class:`CommandError` + # Any command error that was raised during a check call will be propagated + # by this function. + + # Returns + # -------- + # :class:`bool` + # A boolean indicating if the command can be invoked. + # """ + # if not self.enabled: + # raise DisabledCommand(f"{self.name} command is disabled") + + # original = ctx.command + # ctx.command = self + + # predicates = self.checks + # if self.parent is not None: + # # parent checks should be run first + # predicates = self.parent.checks + predicates + + def add_check(self, func: Check) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`.check`. + + Parameters + ----------- + func: Callable + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + Parameters + ----------- + func: Callable + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + + + @property + def qualified_name(self) -> str: + """:class:`str`: Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``?one two three`` the qualified name would be + ``one two three``. + """ + if not self.parent: + return self.name + + return f"{self.parent.qualified_name}" + + def _prepare_cooldowns(self, ctx: ContextT): + if not self._buckets.valid: + return + + current = datetime.datetime.now().timestamp() + bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message + + if bucket: + retry_after = bucket.update_rate_limit(current) + + if retry_after: + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + + async def call_before_hooks(self, ctx) -> None: + # now that we're done preparing we can call the pre-command hooks + # first, call the command local hook: + cog = self.cog + if self._before_invoke is not None: + # should be cog if @commands.before_invoke is used + instance = getattr(self._before_invoke, "__self__", cog) + # __self__ only exists for methods, not functions + # however, if @command.before_invoke is used, it will be a function + if instance: + await self._before_invoke(instance, ctx) # type: ignore + else: + await self._before_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = cog.__class__._get_overridden_method(cog.cog_before_invoke) + if hook is not None: + await hook(ctx) + + # call the bot global hook if necessary + hook = ctx.bot._before_invoke + if hook is not None: + await hook(ctx) + + async def call_after_hooks(self, ctx: Context) -> None: + cog = self.cog + if self._after_invoke is not None: + instance = getattr(self._after_invoke, "__self__", cog) + if instance: + await self._after_invoke(instance, ctx) # type: ignore + else: + await self._after_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = Cog._get_overridden_method(cog.cog_after_invoke) + if hook is not None: + await hook(ctx) + + hook = ctx.bot._after_invoke + if hook is not None: + await hook(ctx) + + async def _parse_arguments(self, ctx: ContextT) -> None: + return + + async def prepare(self, ctx: ContextT) -> None: + ctx.command = self + + if not await self.can_run(ctx): + raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") + + if self._max_concurrency is not None: + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) # type: ignore + + try: + self._prepare_cooldowns(ctx) + await self._parse_arguments(ctx) + + await self.call_before_hooks(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) # type: ignore + raise + + async def invoke(self, ctx: Context) -> None: + await self.prepare(ctx) + + # terminate the invoked_subcommand chain. + # since we're in a regular command (and not a group) then + # the invoked subcommand is None. + # ctx.invoked_subcommand = None + # ctx.subcommand_passed = None + injected = hooked_wrapped_callback(self, ctx, self.callback) + await injected(*ctx.args, **ctx.kwargs) From 0cbe0e7caf4b2901b44eaeb8e2b7897d6871da4e Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 11:08:18 -0400 Subject: [PATCH 02/54] Integrate invokables into app commands --- discord/abc.py | 123 ++++++++-- discord/commands/core.py | 483 +++------------------------------------ 2 files changed, 138 insertions(+), 468 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 445fa23d23..91ea6a30da 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1855,16 +1855,14 @@ class Mentionable: from typing import Coroutine, TypeVar from .utils import async_all from .errors import ApplicationCommandError, ApplicationCommandInvokeError -from .ext.commands import ( - DisabledCommand, - CheckFailure, - CommandOnCooldown, +from .ext.commands.cooldowns import ( CooldownMapping, BucketType, - MaxConcurrency, - CommandError ) +CheckFailure = CommandOnCooldown = MaxConcurrency = CommandError = Exception + + if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -1936,6 +1934,8 @@ def __init__(self, func: CallbackT, **kwargs): self.cog = None self.module = None self.enabled = kwargs.get("enabled", True) + self.name = kwargs.get("name", func.__name__) + self.parent = kwargs.get("parent") # checks if checks := getattr(func, "__commands_checks__", []): @@ -1981,6 +1981,36 @@ def callback(self, func: CallbackT) -> None: unwrap = unwrap_function(func) self.module = unwrap.__module__ + @property + def cooldown(self): + return self._buckets._cooldown + + @property + def full_parent_name(self) -> Optional[str]: + """:class:`str`: Retrieves the fully qualified parent command name. + + This the base command name required to execute it. For example, + in ``/one two three`` the parent name would be ``one two``. + """ + if self.parent: + return self.parent.qualified_name + + @property + def qualified_name(self) -> str: + """:class:`str`: Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``?one two three`` the qualified name would be + ``one two three``. + """ + if not self.parent: + return self.name + + return f"{self.parent.qualified_name} {self.name}" + + def __str__(self) -> str: + return self.qualified_name + async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): """|coro| @@ -2147,20 +2177,6 @@ def remove_check(self, func: Check) -> None: except ValueError: pass - - @property - def qualified_name(self) -> str: - """:class:`str`: Retrieves the fully qualified command name. - - This is the full parent name with the command name as well. - For example, in ``?one two three`` the qualified name would be - ``one two three``. - """ - if not self.parent: - return self.name - - return f"{self.parent.qualified_name}" - def _prepare_cooldowns(self, ctx: ContextT): if not self._buckets.valid: return @@ -2174,7 +2190,68 @@ def _prepare_cooldowns(self, ctx: ContextT): if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - async def call_before_hooks(self, ctx) -> None: + def is_on_cooldown(self, ctx: ContextT) -> bool: + """Checks whether the command is currently on cooldown. + + .. note:: + + This uses the current time instead of the interaction time. + + Parameters + ----------- + ctx: :class:`.ApplicationContext` + The invocation context to use when checking the command's cooldown status. + + Returns + -------- + :class:`bool` + A boolean indicating if the command is on cooldown. + """ + if not self._buckets.valid: + return False + + bucket = self._buckets.get_bucket(ctx) + current = utils.utcnow().timestamp() + return bucket.get_tokens(current) == 0 + + def reset_cooldown(self, ctx) -> None: + """Resets the cooldown on this command. + + Parameters + ----------- + ctx: :class:`.ApplicationContext` + The invocation context to reset the cooldown under. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message + bucket.reset() + + def get_cooldown_retry_after(self, ctx) -> float: + """Retrieves the amount of seconds before this command can be tried again. + + .. note:: + + This uses the current time instead of the interaction time. + + Parameters + ----------- + ctx: :class:`.ApplicationContext` + The invocation context to retrieve the cooldown from. + + Returns + -------- + :class:`float` + The amount of time left on this command's cooldown in seconds. + If this is ``0.0`` then the command isn't on cooldown. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) + current = utils.utcnow().timestamp() + return bucket.get_retry_after(current) + + return 0.0 + + async def call_before_hooks(self, ctx: ContextT) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog @@ -2199,7 +2276,7 @@ async def call_before_hooks(self, ctx) -> None: if hook is not None: await hook(ctx) - async def call_after_hooks(self, ctx: Context) -> None: + async def call_after_hooks(self, ctx: ContextT) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, "__self__", cog) @@ -2241,7 +2318,7 @@ async def prepare(self, ctx: ContextT) -> None: await self._max_concurrency.release(ctx) # type: ignore raise - async def invoke(self, ctx: Context) -> None: + async def invoke(self, ctx: ContextT) -> None: await self.prepare(ctx) # terminate the invoked_subcommand chain. diff --git a/discord/commands/core.py b/discord/commands/core.py index 71398223f6..222410aa15 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -63,6 +63,7 @@ from ..role import Role from ..threads import Thread from ..user import User +from ..abc import Invokable from ..utils import async_all, find, utcnow, maybe_coroutine, MISSING from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice @@ -171,40 +172,12 @@ class _BaseCommand: __slots__ = () -class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]): +class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): __original_kwargs__: Dict[str, Any] cog = None def __init__(self, func: Callable, **kwargs) -> None: - from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency - - cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) - - if cooldown is None: - buckets = CooldownMapping(cooldown, BucketType.default) - elif isinstance(cooldown, CooldownMapping): - buckets = cooldown - else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") - - self._buckets: CooldownMapping = buckets - - max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) - - self._max_concurrency: Optional[MaxConcurrency] = max_concurrency - - self._callback = None - self.module = None - - self.name: str = kwargs.get("name", func.__name__) - - try: - checks = func.__commands_checks__ - checks.reverse() - except AttributeError: - checks = kwargs.get("checks", []) - - self.checks = checks + super().__init__(func, **kwargs) self.id: Optional[int] = kwargs.get("id") self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None) self.parent = kwargs.get("parent") @@ -225,133 +198,6 @@ def __eq__(self, other) -> bool: check = self.name == other.name and self.guild_ids == other.guild_ids return isinstance(other, self.__class__) and self.parent == other.parent and check - async def __call__(self, ctx, *args, **kwargs): - """|coro| - Calls the command's callback. - - This method bypasses all checks that a command has and does not - convert the arguments beforehand, so take care to pass the correct - arguments in. - """ - if self.cog is not None: - return await self.callback(self.cog, ctx, *args, **kwargs) - return await self.callback(ctx, *args, **kwargs) - - @property - def callback( - self, - ) -> Union[ - Callable[[Concatenate[CogT, ApplicationContext, P]], Coro[T]], - Callable[[Concatenate[ApplicationContext, P]], Coro[T]], - ]: - return self._callback - - @callback.setter - def callback( - self, - function: Union[ - Callable[[Concatenate[CogT, ApplicationContext, P]], Coro[T]], - Callable[[Concatenate[ApplicationContext, P]], Coro[T]], - ], - ) -> None: - self._callback = function - unwrap = unwrap_function(function) - self.module = unwrap.__module__ - - def _prepare_cooldowns(self, ctx: ApplicationContext): - if self._buckets.valid: - current = datetime.datetime.now().timestamp() - bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message - - if bucket is not None: - retry_after = bucket.update_rate_limit(current) - - if retry_after: - from ..ext.commands.errors import CommandOnCooldown - - raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - - async def prepare(self, ctx: ApplicationContext) -> None: - # This should be same across all 3 types - ctx.command = self - - if not await self.can_run(ctx): - raise CheckFailure(f"The check functions for the command {self.name} failed") - - if hasattr(self, "_max_concurrency"): - if self._max_concurrency is not None: - # For this application, context can be duck-typed as a Message - await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message - - try: - self._prepare_cooldowns(ctx) - await self.call_before_hooks(ctx) - except: - if self._max_concurrency is not None: - await self._max_concurrency.release(ctx) # type: ignore # ctx instead of non-existent message - raise - - def is_on_cooldown(self, ctx: ApplicationContext) -> bool: - """Checks whether the command is currently on cooldown. - - .. note:: - - This uses the current time instead of the interaction time. - - Parameters - ----------- - ctx: :class:`.ApplicationContext` - The invocation context to use when checking the command's cooldown status. - - Returns - -------- - :class:`bool` - A boolean indicating if the command is on cooldown. - """ - if not self._buckets.valid: - return False - - bucket = self._buckets.get_bucket(ctx) - current = utcnow().timestamp() - return bucket.get_tokens(current) == 0 - - def reset_cooldown(self, ctx: ApplicationContext) -> None: - """Resets the cooldown on this command. - - Parameters - ----------- - ctx: :class:`.ApplicationContext` - The invocation context to reset the cooldown under. - """ - if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message - bucket.reset() - - def get_cooldown_retry_after(self, ctx: ApplicationContext) -> float: - """Retrieves the amount of seconds before this command can be tried again. - - .. note:: - - This uses the current time instead of the interaction time. - - Parameters - ----------- - ctx: :class:`.ApplicationContext` - The invocation context to retrieve the cooldown from. - - Returns - -------- - :class:`float` - The amount of time left on this command's cooldown in seconds. - If this is ``0.0`` then the command isn't on cooldown. - """ - if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx) - current = utcnow().timestamp() - return bucket.get_retry_after(current) - - return 0.0 - async def invoke(self, ctx: ApplicationContext) -> None: await self.prepare(ctx) @@ -359,7 +205,6 @@ async def invoke(self, ctx: ApplicationContext) -> None: await injected(ctx) async def can_run(self, ctx: ApplicationContext) -> bool: - if not await ctx.bot.can_run(ctx): raise CheckFailure(f"The global check functions for command {self.name} failed.") @@ -405,168 +250,45 @@ async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> Non finally: ctx.bot.dispatch("application_command_error", ctx, error) - def _get_signature_parameters(self): - return OrderedDict(inspect.signature(self.callback).parameters) - - def error(self, coro): - """A decorator that registers a coroutine as a local error handler. - - A local error handler is an :func:`.on_command_error` event limited to - a single command. However, the :func:`.on_command_error` is still - invoked afterwards as the catch-all. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the local error handler. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The error handler must be a coroutine.") - - self.on_error = coro - return coro - - def has_error_handler(self) -> bool: - """:class:`bool`: Checks whether the command has an error handler registered.""" - return hasattr(self, "on_error") - - def before_invoke(self, coro): - """A decorator that registers a coroutine as a pre-invoke hook. - A pre-invoke hook is called directly before the command is - called. This makes it a useful function to set up database - connections or any type of set up required. - - This pre-invoke hook takes a sole parameter, a :class:`.ApplicationContext`. - See :meth:`.Bot.before_invoke` for more info. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the pre-invoke hook. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The pre-invoke hook must be a coroutine.") - - self._before_invoke = coro - return coro - - def after_invoke(self, coro): - """A decorator that registers a coroutine as a post-invoke hook. - A post-invoke hook is called directly after the command is - called. This makes it a useful function to clean-up database - connections or any type of clean up required. - - This post-invoke hook takes a sole parameter, a :class:`.ApplicationContext`. - See :meth:`.Bot.after_invoke` for more info. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the post-invoke hook. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The post-invoke hook must be a coroutine.") - - self._after_invoke = coro - return coro - - async def call_before_hooks(self, ctx: ApplicationContext) -> None: - # now that we're done preparing we can call the pre-command hooks - # first, call the command local hook: - cog = self.cog - if self._before_invoke is not None: - # should be cog if @commands.before_invoke is used - instance = getattr(self._before_invoke, "__self__", cog) - # __self__ only exists for methods, not functions - # however, if @command.before_invoke is used, it will be a function - if instance: - await self._before_invoke(instance, ctx) # type: ignore - else: - await self._before_invoke(ctx) # type: ignore - - # call the cog local hook if applicable: - if cog is not None: - hook = cog.__class__._get_overridden_method(cog.cog_before_invoke) - if hook is not None: - await hook(ctx) - - # call the bot global hook if necessary - hook = ctx.bot._before_invoke - if hook is not None: - await hook(ctx) - - async def call_after_hooks(self, ctx: ApplicationContext) -> None: - cog = self.cog - if self._after_invoke is not None: - instance = getattr(self._after_invoke, "__self__", cog) - if instance: - await self._after_invoke(instance, ctx) # type: ignore - else: - await self._after_invoke(ctx) # type: ignore - - # call the cog local hook if applicable: - if cog is not None: - hook = cog.__class__._get_overridden_method(cog.cog_after_invoke) - if hook is not None: - await hook(ctx) - - hook = ctx.bot._after_invoke - if hook is not None: - await hook(ctx) - - @property - def cooldown(self): - return self._buckets._cooldown - - @property - def full_parent_name(self) -> str: - """:class:`str`: Retrieves the fully qualified parent command name. + def copy(self): + """Creates a copy of this command. - This the base command name required to execute it. For example, - in ``/one two three`` the parent name would be ``one two``. + Returns + -------- + :class:`SlashCommand` + A new instance of this command. """ - entries = [] - command = self - while command.parent is not None and hasattr(command.parent, "name"): - command = command.parent - entries.append(command.name) - - return " ".join(reversed(entries)) - - @property - def qualified_name(self) -> str: - """:class:`str`: Retrieves the fully qualified command name. + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) - This is the full parent name with the command name as well. - For example, in ``/one two three`` the qualified name would be - ``one two three``. - """ + def _ensure_assignment_on_copy(self, other): + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + if self._buckets.valid and not other._buckets.valid: + other._buckets = self._buckets.copy() + if self._max_concurrency != other._max_concurrency: + # _max_concurrency won't be None at this point + other._max_concurrency = self._max_concurrency.copy() # type: ignore - parent = self.full_parent_name + try: + other.on_error = self.on_error + except AttributeError: + pass + return other - if parent: - return f"{parent} {self.name}" + def _update_copy(self, kwargs: Dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) else: - return self.name + return self.copy() - def __str__(self) -> str: - return self.qualified_name + def _get_signature_parameters(self): + return OrderedDict(inspect.signature(self.callback).parameters) def _set_cog(self, cog): self.cog = cog @@ -630,9 +352,6 @@ def __new__(cls, *args, **kwargs) -> SlashCommand: def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback must be a coroutine.") - self.callback = func self.name_localizations: Optional[Dict[str, str]] = kwargs.get("name_localizations", None) _validate_names(self) @@ -649,17 +368,6 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self.options: List[Option] = kwargs.get("options", []) - try: - checks = func.__commands_checks__ - checks.reverse() - except AttributeError: - checks = kwargs.get("checks", []) - - self.checks = checks - - self._before_invoke = None - self._after_invoke = None - self._cog = MISSING def _validate_parameters(self): @@ -805,7 +513,7 @@ def to_dict(self) -> Dict: return as_dict - async def _invoke(self, ctx: ApplicationContext) -> None: + async def _parse_arguments(self, ctx: ApplicationContext) -> None: # TODO: Parse the args better kwargs = {} for arg in ctx.interaction.data.get("options", []): @@ -900,12 +608,8 @@ async def _invoke(self, ctx: ApplicationContext) -> None: if o._parameter_name not in kwargs: kwargs[o._parameter_name] = o.default - if self.cog is not None: - await self.callback(self.cog, ctx, **kwargs) - elif self.parent is not None and self.attached_to_group is True: - await self.callback(self.parent, ctx, **kwargs) - else: - await self.callback(ctx, **kwargs) + ctx.args = [] + ctx.kwargs = kwargs async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): values = {i.name: i.default for i in self.options} @@ -931,43 +635,6 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): choices = [o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result][:25] return await ctx.interaction.response.send_autocomplete_result(choices=choices) - def copy(self): - """Creates a copy of this command. - - Returns - -------- - :class:`SlashCommand` - A new instance of this command. - """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) - - def _ensure_assignment_on_copy(self, other): - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - # if self._buckets.valid and not other._buckets.valid: - # other._buckets = self._buckets.copy() - # if self._max_concurrency != other._max_concurrency: - # # _max_concurrency won't be None at this point - # other._max_concurrency = self._max_concurrency.copy() # type: ignore - - try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def _update_copy(self, kwargs: Dict[str, Any]): - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - def _set_cog(self, cog): super()._set_cog(cog) self._validate_parameters() @@ -1446,43 +1113,6 @@ async def _invoke(self, ctx: ApplicationContext) -> None: else: await self.callback(ctx, target) - def copy(self): - """Creates a copy of this command. - - Returns - -------- - :class:`UserCommand` - A new instance of this command. - """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) - - def _ensure_assignment_on_copy(self, other): - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - # if self._buckets.valid and not other._buckets.valid: - # other._buckets = self._buckets.copy() - # if self._max_concurrency != other._max_concurrency: - # # _max_concurrency won't be None at this point - # other._max_concurrency = self._max_concurrency.copy() # type: ignore - - try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def _update_copy(self, kwargs: Dict[str, Any]): - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - class MessageCommand(ContextMenuCommand): r"""A class that implements the protocol for message context menu commands. @@ -1543,43 +1173,6 @@ async def _invoke(self, ctx: ApplicationContext): else: await self.callback(ctx, target) - def copy(self): - """Creates a copy of this command. - - Returns - -------- - :class:`MessageCommand` - A new instance of this command. - """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) - - def _ensure_assignment_on_copy(self, other): - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - # if self._buckets.valid and not other._buckets.valid: - # other._buckets = self._buckets.copy() - # if self._max_concurrency != other._max_concurrency: - # # _max_concurrency won't be None at this point - # other._max_concurrency = self._max_concurrency.copy() # type: ignore - - try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def _update_copy(self, kwargs: Dict[str, Any]): - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - def slash_command(**kwargs): """Decorator for slash commands that invokes :func:`application_command`. From 915534e8e956533448a81f54397d5cfa92a367a9 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 16:05:57 -0400 Subject: [PATCH 03/54] Merge branch 'master' into invokable ik fake merge commit but like I can't merge anything really --- discord/commands/core.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index 222410aa15..4b0dd48721 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -635,10 +635,6 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): choices = [o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result][:25] return await ctx.interaction.response.send_autocomplete_result(choices=choices) - def _set_cog(self, cog): - super()._set_cog(cog) - self._validate_parameters() - class SlashCommandGroup(ApplicationCommand): r"""A class that implements the protocol for a slash command group. @@ -757,10 +753,16 @@ def to_dict(self) -> Dict: return as_dict + def add_command(self, command: SlashCommand) -> None: + if command.cog is MISSING: + command.cog = self.cog + + self.subcommands.append(command) + def command(self, cls: Type[T] = SlashCommand, **kwargs) -> Callable[[Callable], SlashCommand]: def wrap(func) -> T: command = cls(func, parent=self, **kwargs) - self.subcommands.append(command) + self.add_command(command) return command return wrap From 70ed7d5513d066418b6d5e054dd0694bf074c21a Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 17:05:07 -0400 Subject: [PATCH 04/54] Make fully working with app commands * added can_run method * move cooldowns/max conc. * copy methods * compatibility fixes etc (including circular imports) * it works now btw --- discord/abc.py | 478 ------------------------ discord/bot.py | 1 - discord/commands/cooldowns.py | 391 ++++++++++++++++++++ discord/commands/core.py | 153 +------- discord/commands/invokable.py | 586 ++++++++++++++++++++++++++++++ discord/errors.py | 122 ++++++- discord/ext/commands/cooldowns.py | 359 +----------------- discord/ext/commands/errors.py | 110 +----- 8 files changed, 1101 insertions(+), 1099 deletions(-) create mode 100644 discord/commands/cooldowns.py create mode 100644 discord/commands/invokable.py diff --git a/discord/abc.py b/discord/abc.py index 91ea6a30da..00a6a13c10 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1850,481 +1850,3 @@ async def connect( class Mentionable: # TODO: documentation, methods if needed pass - -import functools -from typing import Coroutine, TypeVar -from .utils import async_all -from .errors import ApplicationCommandError, ApplicationCommandInvokeError -from .ext.commands.cooldowns import ( - CooldownMapping, - BucketType, -) - -CheckFailure = CommandOnCooldown = MaxConcurrency = CommandError = Exception - - - -if TYPE_CHECKING: - from typing_extensions import ParamSpec - from .cog import Cog - - P = ParamSpec("P") -else: - P = TypeVar("P") - -def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: - partial = functools.partial - while True: - if hasattr(function, "__wrapped__"): - function = function.__wrapped__ - elif isinstance(function, partial): - function = function.func - else: - return function - -def hooked_wrapped_callback(command: Invokable, ctx: ContextT, coro: CallbackT): - @functools.wraps(coro) - async def wrapped(arg): - try: - ret = await coro(arg) - except (ApplicationCommandError, CommandError): - raise - except asyncio.CancelledError: - return - except Exception as exc: - raise ApplicationCommandInvokeError(exc) from exc - finally: - if command._max_concurrency is not None: - await command._max_concurrency.release(ctx) - await command.call_after_hooks(ctx) - - return ret - - return wrapped - - -CallbackT = TypeVar("CallbackT") -ErrorT = TypeVar("ErrorT") -HookT = TypeVar("HookT") -ContextT = TypeVar("ContextT") -MaybeCoro = Union[Any, Coroutine[Any]] -Check = Union[ - Callable[[ContextT], MaybeCoro[bool]], - Callable[[Cog, ContextT], MaybeCoro[bool]] -] - -class Invokable: - _callback: CallbackT - cog: Optional[Cog] - parent: Optional[Invokable] - module: Any - - name: str - enabled: bool - - checks: List[Check] - _buckets: CooldownMapping - _max_concurrency: Optional[MaxConcurrency] - on_error: Optional[ErrorT] - _before_invoke: Optional[HookT] - _after_invoke: Optional[HookT] - - def __init__(self, func: CallbackT, **kwargs): - self.callback = func - self.cog = None - self.module = None - self.enabled = kwargs.get("enabled", True) - self.name = kwargs.get("name", func.__name__) - self.parent = kwargs.get("parent") - - # checks - if checks := getattr(func, "__commands_checks__", []): - checks.reverse() - - checks += kwargs.get("checks", []) # combine all the checks we find (kwargs or decorator) - self.checks = checks - - # cooldowns - cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) - - if cooldown is None: - buckets = CooldownMapping(cooldown, BucketType.default) - elif isinstance(cooldown, CooldownMapping): - buckets = cooldown - else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") - - self._buckets = buckets - - # max concurrency - self._max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) - - # hooks - self._before_invoke = None - if hook := getattr(func, "__before_invoke__", None): - self.before_invoke(hook) - - self._after_invoke = None - if hook := getattr(func, "__after_invoke__", None): - self.after_invoke(hook) - - @property - def callback(self) -> CallbackT: - return self._callback - - @callback.setter - def callback(self, func: CallbackT) -> None: - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback must be a coroutine.") - - self._callback = func - unwrap = unwrap_function(func) - self.module = unwrap.__module__ - - @property - def cooldown(self): - return self._buckets._cooldown - - @property - def full_parent_name(self) -> Optional[str]: - """:class:`str`: Retrieves the fully qualified parent command name. - - This the base command name required to execute it. For example, - in ``/one two three`` the parent name would be ``one two``. - """ - if self.parent: - return self.parent.qualified_name - - @property - def qualified_name(self) -> str: - """:class:`str`: Retrieves the fully qualified command name. - - This is the full parent name with the command name as well. - For example, in ``?one two three`` the qualified name would be - ``one two three``. - """ - if not self.parent: - return self.name - - return f"{self.parent.qualified_name} {self.name}" - - def __str__(self) -> str: - return self.qualified_name - - async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): - """|coro| - - Calls the internal callback that the command holds. - - .. note:: - - This bypasses all mechanisms -- including checks, converters, - invoke hooks, cooldowns, etc. You must take care to pass - the proper arguments and types to this function. - - """ - if self.cog is not None: - return await self.callback(self.cog, ctx, *args, **kwargs) - return await self.callback(ctx, *args, **kwargs) - - def error(self, coro: ErrorT) -> ErrorT: - """A decorator that registers a coroutine as a local error handler. - - A local error handler is an :func:`.on_command_error` event limited to - a single command. However, the :func:`.on_command_error` is still - invoked afterwards as the catch-all. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the local error handler. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The error handler must be a coroutine.") - - self.on_error = coro - return coro - - def has_error_handler(self) -> bool: - """:class:`bool`: Checks whether the command has an error handler registered.""" - return hasattr(self, "on_error") - - def before_invoke(self, coro: HookT) -> HookT: - """A decorator that registers a coroutine as a pre-invoke hook. - - A pre-invoke hook is called directly before the command is - called. This makes it a useful function to set up database - connections or any type of set up required. - - This pre-invoke hook takes a sole parameter, a :class:`.Context`. - - See :meth:`.Bot.before_invoke` for more info. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the pre-invoke hook. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The pre-invoke hook must be a coroutine.") - - self._before_invoke = coro - return coro - - def after_invoke(self, coro: HookT) -> HookT: - """A decorator that registers a coroutine as a post-invoke hook. - - A post-invoke hook is called directly after the command is - called. This makes it a useful function to clean-up database - connections or any type of clean up required. - - This post-invoke hook takes a sole parameter, a :class:`.Context`. - - See :meth:`.Bot.after_invoke` for more info. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the post-invoke hook. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The post-invoke hook must be a coroutine.") - - self._after_invoke = coro - return coro - - # might require more of a library wide refactoring - - # async def can_run(self, ctx: ContextT) -> bool: - # """|coro| - - # Checks if the command can be executed by checking all the predicates - # inside the :attr:`~Command.checks` attribute. This also checks whether the - # command is disabled. - - # .. versionchanged:: 1.3 - # Checks whether the command is disabled or not - - # Parameters - # ----------- - # ctx: :class:`.Context` - # The ctx of the command currently being invoked. - - # Raises - # ------- - # :class:`CommandError` - # Any command error that was raised during a check call will be propagated - # by this function. - - # Returns - # -------- - # :class:`bool` - # A boolean indicating if the command can be invoked. - # """ - # if not self.enabled: - # raise DisabledCommand(f"{self.name} command is disabled") - - # original = ctx.command - # ctx.command = self - - # predicates = self.checks - # if self.parent is not None: - # # parent checks should be run first - # predicates = self.parent.checks + predicates - - def add_check(self, func: Check) -> None: - """Adds a check to the command. - - This is the non-decorator interface to :func:`.check`. - - Parameters - ----------- - func: Callable - The function that will be used as a check. - """ - - self.checks.append(func) - - def remove_check(self, func: Check) -> None: - """Removes a check from the command. - - This function is idempotent and will not raise an exception - if the function is not in the command's checks. - - Parameters - ----------- - func: Callable - The function to remove from the checks. - """ - - try: - self.checks.remove(func) - except ValueError: - pass - - def _prepare_cooldowns(self, ctx: ContextT): - if not self._buckets.valid: - return - - current = datetime.datetime.now().timestamp() - bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message - - if bucket: - retry_after = bucket.update_rate_limit(current) - - if retry_after: - raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - - def is_on_cooldown(self, ctx: ContextT) -> bool: - """Checks whether the command is currently on cooldown. - - .. note:: - - This uses the current time instead of the interaction time. - - Parameters - ----------- - ctx: :class:`.ApplicationContext` - The invocation context to use when checking the command's cooldown status. - - Returns - -------- - :class:`bool` - A boolean indicating if the command is on cooldown. - """ - if not self._buckets.valid: - return False - - bucket = self._buckets.get_bucket(ctx) - current = utils.utcnow().timestamp() - return bucket.get_tokens(current) == 0 - - def reset_cooldown(self, ctx) -> None: - """Resets the cooldown on this command. - - Parameters - ----------- - ctx: :class:`.ApplicationContext` - The invocation context to reset the cooldown under. - """ - if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message - bucket.reset() - - def get_cooldown_retry_after(self, ctx) -> float: - """Retrieves the amount of seconds before this command can be tried again. - - .. note:: - - This uses the current time instead of the interaction time. - - Parameters - ----------- - ctx: :class:`.ApplicationContext` - The invocation context to retrieve the cooldown from. - - Returns - -------- - :class:`float` - The amount of time left on this command's cooldown in seconds. - If this is ``0.0`` then the command isn't on cooldown. - """ - if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx) - current = utils.utcnow().timestamp() - return bucket.get_retry_after(current) - - return 0.0 - - async def call_before_hooks(self, ctx: ContextT) -> None: - # now that we're done preparing we can call the pre-command hooks - # first, call the command local hook: - cog = self.cog - if self._before_invoke is not None: - # should be cog if @commands.before_invoke is used - instance = getattr(self._before_invoke, "__self__", cog) - # __self__ only exists for methods, not functions - # however, if @command.before_invoke is used, it will be a function - if instance: - await self._before_invoke(instance, ctx) # type: ignore - else: - await self._before_invoke(ctx) # type: ignore - - # call the cog local hook if applicable: - if cog is not None: - hook = cog.__class__._get_overridden_method(cog.cog_before_invoke) - if hook is not None: - await hook(ctx) - - # call the bot global hook if necessary - hook = ctx.bot._before_invoke - if hook is not None: - await hook(ctx) - - async def call_after_hooks(self, ctx: ContextT) -> None: - cog = self.cog - if self._after_invoke is not None: - instance = getattr(self._after_invoke, "__self__", cog) - if instance: - await self._after_invoke(instance, ctx) # type: ignore - else: - await self._after_invoke(ctx) # type: ignore - - # call the cog local hook if applicable: - if cog is not None: - hook = Cog._get_overridden_method(cog.cog_after_invoke) - if hook is not None: - await hook(ctx) - - hook = ctx.bot._after_invoke - if hook is not None: - await hook(ctx) - - async def _parse_arguments(self, ctx: ContextT) -> None: - return - - async def prepare(self, ctx: ContextT) -> None: - ctx.command = self - - if not await self.can_run(ctx): - raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") - - if self._max_concurrency is not None: - # For this application, context can be duck-typed as a Message - await self._max_concurrency.acquire(ctx) # type: ignore - - try: - self._prepare_cooldowns(ctx) - await self._parse_arguments(ctx) - - await self.call_before_hooks(ctx) - except: - if self._max_concurrency is not None: - await self._max_concurrency.release(ctx) # type: ignore - raise - - async def invoke(self, ctx: ContextT) -> None: - await self.prepare(ctx) - - # terminate the invoked_subcommand chain. - # since we're in a regular command (and not a group) then - # the invoked subcommand is None. - # ctx.invoked_subcommand = None - # ctx.subcommand_passed = None - injected = hooked_wrapped_callback(self, ctx, self.callback) - await injected(*ctx.args, **ctx.kwargs) diff --git a/discord/bot.py b/discord/bot.py index 89f4ec7b2e..2354806a15 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -230,7 +230,6 @@ async def get_desynced_commands( .. versionadded:: 2.0 - Parameters ---------- guild_id: Optional[:class:`int`] diff --git a/discord/commands/cooldowns.py b/discord/commands/cooldowns.py new file mode 100644 index 0000000000..3678d45e7c --- /dev/null +++ b/discord/commands/cooldowns.py @@ -0,0 +1,391 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque +from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Optional, Type, TypeVar + +from discord.enums import Enum + +from ..abc import PrivateChannel +from ..errors import MaxConcurrencyReached + +if TYPE_CHECKING: + from ..message import Message + +__all__ = ( + "BucketType", + "Cooldown", + "CooldownMapping", + "DynamicCooldownMapping", + "MaxConcurrency", +) + +C = TypeVar("C", bound="CooldownMapping") +MC = TypeVar("MC", bound="MaxConcurrency") + + +class BucketType(Enum): + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 + category = 5 + role = 6 + + def get_key(self, msg: Message) -> Any: + if self is BucketType.user: + return msg.author.id + elif self is BucketType.guild: + return (msg.guild or msg.author).id + elif self is BucketType.channel: + return msg.channel.id + elif self is BucketType.member: + return (msg.guild and msg.guild.id), msg.author.id + elif self is BucketType.category: + return (msg.channel.category or msg.channel).id # type: ignore + elif self is BucketType.role: + # we return the channel id of a private-channel as there are only roles in guilds + # and that yields the same result as for a guild with only the @everyone role + # NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are + # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do + return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore + + def __call__(self, msg: Message) -> Any: + return self.get_key(msg) + + +class Cooldown: + """Represents a cooldown for a command. + + Attributes + ----------- + rate: :class:`int` + The total number of tokens available per :attr:`per` seconds. + per: :class:`float` + The length of the cooldown period in seconds. + """ + + __slots__ = ("rate", "per", "_window", "_tokens", "_last") + + def __init__(self, rate: float, per: float) -> None: + self.rate: int = int(rate) + self.per: float = float(per) + self._window: float = 0.0 + self._tokens: int = self.rate + self._last: float = 0.0 + + def get_tokens(self, current: Optional[float] = None) -> int: + """Returns the number of available tokens before rate limiting is applied. + + Parameters + ------------ + current: Optional[:class:`float`] + The time in seconds since Unix epoch to calculate tokens at. + If not supplied then :func:`time.time()` is used. + + Returns + -------- + :class:`int` + The number of tokens available before the cooldown is to be applied. + """ + if not current: + current = time.time() + + tokens = self._tokens + + if current > self._window + self.per: + tokens = self.rate + return tokens + + def get_retry_after(self, current: Optional[float] = None) -> float: + """Returns the time in seconds until the cooldown will be reset. + + Parameters + ------------- + current: Optional[:class:`float`] + The current time in seconds since Unix epoch. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + :class:`float` + The number of seconds to wait before this cooldown will be reset. + """ + current = current or time.time() + tokens = self.get_tokens(current) + + if tokens == 0: + return self.per - (current - self._window) + + return 0.0 + + def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: + """Updates the cooldown rate limit. + + Parameters + ------------- + current: Optional[:class:`float`] + The time in seconds since Unix epoch to update the rate limit at. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + Optional[:class:`float`] + The retry-after time in seconds if rate limited. + """ + current = current or time.time() + self._last = current + + self._tokens = self.get_tokens(current) + + # first token used means that we start a new rate limit window + if self._tokens == self.rate: + self._window = current + + # check if we are rate limited + if self._tokens == 0: + return self.per - (current - self._window) + + # we're not so decrement our tokens + self._tokens -= 1 + + def reset(self) -> None: + """Reset the cooldown to its initial state.""" + self._tokens = self.rate + self._last = 0.0 + + def copy(self) -> Cooldown: + """Creates a copy of this cooldown. + + Returns + -------- + :class:`Cooldown` + A new instance of this cooldown. + """ + return Cooldown(self.rate, self.per) + + def __repr__(self) -> str: + return f"" + + +class CooldownMapping: + def __init__( + self, + original: Optional[Cooldown], + type: Callable[[Message], Any], + ) -> None: + if not callable(type): + raise TypeError("Cooldown type must be a BucketType or callable") + + self._cache: Dict[Any, Cooldown] = {} + self._cooldown: Optional[Cooldown] = original + self._type: Callable[[Message], Any] = type + + def copy(self) -> CooldownMapping: + ret = CooldownMapping(self._cooldown, self._type) + ret._cache = self._cache.copy() + return ret + + @property + def valid(self) -> bool: + return self._cooldown is not None + + @property + def type(self) -> Callable[[Message], Any]: + return self._type + + @classmethod + def from_cooldown(cls: Type[C], rate, per, type) -> C: + return cls(Cooldown(rate, per), type) + + def _bucket_key(self, msg: Message) -> Any: + return self._type(msg) + + def _verify_cache_integrity(self, current: Optional[float] = None) -> None: + # we want to delete all cache objects that haven't been used + # in a cooldown window. e.g. if we have a command that has a + # cooldown of 60s, and it has not been used in 60s then that key should be deleted + current = current or time.time() + dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per] + for k in dead_keys: + del self._cache[k] + + def create_bucket(self, message: Message) -> Cooldown: + return self._cooldown.copy() # type: ignore + + def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown: + if self._type is BucketType.default: + return self._cooldown # type: ignore + + self._verify_cache_integrity(current) + key = self._bucket_key(message) + if key not in self._cache: + bucket = self.create_bucket(message) + if bucket is not None: + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket + + def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: + bucket = self.get_bucket(message, current) + return bucket.update_rate_limit(current) + + +class DynamicCooldownMapping(CooldownMapping): + def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None: + super().__init__(None, type) + self._factory: Callable[[Message], Cooldown] = factory + + def copy(self) -> DynamicCooldownMapping: + ret = DynamicCooldownMapping(self._factory, self._type) + ret._cache = self._cache.copy() + return ret + + @property + def valid(self) -> bool: + return True + + def create_bucket(self, message: Message) -> Cooldown: + return self._factory(message) + + +class _Semaphore: + """This class is a version of a semaphore. + + If you're wondering why asyncio.Semaphore isn't being used, + it's because it doesn't expose the internal value. This internal + value is necessary because I need to support both `wait=True` and + `wait=False`. + + An asyncio.Queue could have been used to do this as well -- but it is + not as inefficient since internally that uses two queues and is a bit + overkill for what is basically a counter. + """ + + __slots__ = ("value", "loop", "_waiters") + + def __init__(self, number: int) -> None: + self.value: int = number + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._waiters: Deque[asyncio.Future] = deque() + + def __repr__(self) -> str: + return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>" + + def locked(self) -> bool: + return self.value == 0 + + def is_active(self) -> bool: + return len(self._waiters) > 0 + + def wake_up(self) -> None: + while self._waiters: + future = self._waiters.popleft() + if not future.done(): + future.set_result(None) + return + + async def acquire(self, *, wait: bool = False) -> bool: + if not wait and self.value <= 0: + # signal that we're not acquiring + return False + + while self.value <= 0: + future = self.loop.create_future() + self._waiters.append(future) + try: + await future + except: + future.cancel() + if self.value > 0 and not future.cancelled(): + self.wake_up() + raise + + self.value -= 1 + return True + + def release(self) -> None: + self.value += 1 + self.wake_up() + + +class MaxConcurrency: + __slots__ = ("number", "per", "wait", "_mapping") + + def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: + self._mapping: Dict[Any, _Semaphore] = {} + self.per: BucketType = per + self.number: int = number + self.wait: bool = wait + + if number <= 0: + raise ValueError("max_concurrency 'number' cannot be less than 1") + + if not isinstance(per, BucketType): + raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") + + def copy(self: MC) -> MC: + return self.__class__(self.number, per=self.per, wait=self.wait) + + def __repr__(self) -> str: + return f"" + + def get_key(self, message: Message) -> Any: + return self.per.get_key(message) + + async def acquire(self, message: Message) -> None: + key = self.get_key(message) + + try: + sem = self._mapping[key] + except KeyError: + self._mapping[key] = sem = _Semaphore(self.number) + + acquired = await sem.acquire(wait=self.wait) + if not acquired: + raise MaxConcurrencyReached(self.number, self.per) + + async def release(self, message: Message) -> None: + # Technically there's no reason for this function to be async + # But it might be more useful in the future + key = self.get_key(message) + + try: + sem = self._mapping[key] + except KeyError: + # ...? peculiar + return + else: + sem.release() + + if sem.value >= self.number and not sem.is_active(): + del self._mapping[key] diff --git a/discord/commands/core.py b/discord/commands/core.py index 4b0dd48721..2bf3d84978 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -63,8 +63,8 @@ from ..role import Role from ..threads import Thread from ..user import User -from ..abc import Invokable -from ..utils import async_all, find, utcnow, maybe_coroutine, MISSING +from .invokable import Invokable +from ..utils import async_all, find, maybe_coroutine, MISSING from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice @@ -84,7 +84,7 @@ ) if TYPE_CHECKING: - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from .. import Permissions from ..cog import Cog @@ -198,95 +198,6 @@ def __eq__(self, other) -> bool: check = self.name == other.name and self.guild_ids == other.guild_ids return isinstance(other, self.__class__) and self.parent == other.parent and check - async def invoke(self, ctx: ApplicationContext) -> None: - await self.prepare(ctx) - - injected = hooked_wrapped_callback(self, ctx, self._invoke) - await injected(ctx) - - async def can_run(self, ctx: ApplicationContext) -> bool: - if not await ctx.bot.can_run(ctx): - raise CheckFailure(f"The global check functions for command {self.name} failed.") - - predicates = self.checks - if self.parent is not None: - # parent checks should be run first - predicates = self.parent.checks + predicates - - cog = self.cog - if cog is not None: - local_check = cog._get_overridden_method(cog.cog_check) - if local_check is not None: - ret = await maybe_coroutine(local_check, ctx) - if not ret: - return False - - if not predicates: - # since we have no checks, then we just return True. - return True - - return await async_all(predicate(ctx) for predicate in predicates) # type: ignore - - async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> None: - ctx.command_failed = True - cog = self.cog - try: - coro = self.on_error - except AttributeError: - pass - else: - injected = wrap_callback(coro) - if cog is not None: - await injected(cog, ctx, error) - else: - await injected(ctx, error) - - try: - if cog is not None: - local = cog.__class__._get_overridden_method(cog.cog_command_error) - if local is not None: - wrapped = wrap_callback(local) - await wrapped(ctx, error) - finally: - ctx.bot.dispatch("application_command_error", ctx, error) - - def copy(self): - """Creates a copy of this command. - - Returns - -------- - :class:`SlashCommand` - A new instance of this command. - """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) - - def _ensure_assignment_on_copy(self, other): - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - if self._buckets.valid and not other._buckets.valid: - other._buckets = self._buckets.copy() - if self._max_concurrency != other._max_concurrency: - # _max_concurrency won't be None at this point - other._max_concurrency = self._max_concurrency.copy() # type: ignore - - try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def _update_copy(self, kwargs: Dict[str, Any]): - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) @@ -608,7 +519,7 @@ async def _parse_arguments(self, ctx: ApplicationContext) -> None: if o._parameter_name not in kwargs: kwargs[o._parameter_name] = o.default - ctx.args = [] + ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = kwargs async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): @@ -862,7 +773,7 @@ def inner(cls: Type[SlashCommandGroup]) -> SlashCommandGroup: return inner - async def _invoke(self, ctx: ApplicationContext) -> None: + async def invoke(self, ctx: ApplicationContext) -> None: option = ctx.interaction.data["options"][0] resolved = ctx.interaction.data.get("resolved", None) command = find(lambda x: x.name == option["name"], self.subcommands) @@ -889,48 +800,6 @@ def walk_commands(self) -> Generator[SlashCommand, None, None]: yield from command.walk_commands() yield command - def copy(self): - """Creates a copy of this command group. - - Returns - -------- - :class:`SlashCommandGroup` - A new instance of this command group. - """ - ret = self.__class__( - name=self.name, - description=self.description, - **{ - param: value - for param, value in self.__original_kwargs__.items() - if param not in ("name", "description") - }, - ) - return self._ensure_assignment_on_copy(ret) - - def _ensure_assignment_on_copy(self, other): - other.parent = self.parent - - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - - if self.subcommands != other.subcommands: - other.subcommands = self.subcommands.copy() - - if self.checks != other.checks: - other.checks = self.checks.copy() - - return other - - def _update_copy(self, kwargs: Dict[str, Any]): - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - def _set_cog(self, cog): super()._set_cog(cog) for subcommand in self.subcommands: @@ -982,23 +851,15 @@ def __new__(cls, *args, **kwargs) -> ContextMenuCommand: def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback must be a coroutine.") - self.callback = func self.name_localizations: Optional[Dict[str, str]] = kwargs.get("name_localizations", None) # Discord API doesn't support setting descriptions for context menu commands, so it must be empty self.description = "" - if not isinstance(self.name, str): - raise TypeError("Name of a command must be a string.") self.cog = None self.id = None - self._before_invoke = None - self._after_invoke = None - self.validate_parameters() # Context Menu commands can't have parents @@ -1032,10 +893,6 @@ def validate_parameters(self): except StopIteration: pass - @property - def qualified_name(self): - return self.name - def to_dict(self) -> Dict[str, Union[str, int]]: as_dict = { "name": self.name, diff --git a/discord/commands/invokable.py b/discord/commands/invokable.py new file mode 100644 index 0000000000..2252efe56d --- /dev/null +++ b/discord/commands/invokable.py @@ -0,0 +1,586 @@ +from __future__ import annotations + +import asyncio +import datetime +import functools +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + List, + Optional, + TypeVar, + Union, +) + +from .. import utils +from ..errors import ( + ApplicationCommandError, + CheckFailure, + CommandError, + CommandInvokeError, + DisabledCommand, +) +from .cooldowns import BucketType, CooldownMapping, MaxConcurrency + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from ..cog import Cog + + P = ParamSpec("P") + + CallbackT = TypeVar("CallbackT") + ErrorT = TypeVar("ErrorT") + HookT = TypeVar("HookT") + ContextT = TypeVar("ContextT") + + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + MaybeCoro = Union[T, Coro[T]] + + Check = Union[ + Callable[[ContextT], MaybeCoro[bool]], + Callable[[Cog, ContextT], MaybeCoro[bool]] + ] + + +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, "__wrapped__"): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + + +def wrap_callback(coro): + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except CommandError: + raise + except asyncio.CancelledError: + return + except Exception as exc: + raise CommandInvokeError(exc) from exc + return ret + + return wrapped + + +def hooked_wrapped_callback(command: Invokable, ctx: ContextT, coro: CallbackT): + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except (ApplicationCommandError, CommandError): + ctx.command_failed = True + raise + except asyncio.CancelledError: + ctx.command_failed = True + return + except Exception as exc: + ctx.command_failed = True + raise CommandInvokeError(exc) from exc + finally: + if command._max_concurrency is not None: + await command._max_concurrency.release(ctx) + await command.call_after_hooks(ctx) + + return ret + + return wrapped + + +class Invokable: + checks: List[Check] + _buckets: CooldownMapping + _max_concurrency: Optional[MaxConcurrency] + on_error: Optional[ErrorT] + _before_invoke: Optional[HookT] + _after_invoke: Optional[HookT] + + def __init__(self, func: CallbackT, **kwargs): + self.module: Any = None + self.cog: Optional[Cog] + self.parent: Optional[Invokable] = kwargs.get("parent") + self.callback: CallbackT = func + + self.name: str = str(kwargs.get("name", func.__name__)) + self.enabled: bool = kwargs.get("enabled", True) + + # checks + if checks := getattr(func, "__commands_checks__", []): + checks.reverse() + + checks += kwargs.get("checks", []) # combine all the checks we find (kwargs or decorator) + self.checks = checks + + # cooldowns + cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) + + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets = cooldown + else: + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + + self._buckets = buckets + + # max concurrency + self._max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) + + # hooks + self._before_invoke = None + if hook := getattr(func, "__before_invoke__", None): + self.before_invoke(hook) + + self._after_invoke = None + if hook := getattr(func, "__after_invoke__", None): + self.after_invoke(hook) + + @property + def callback(self) -> CallbackT: + return self._callback + + @callback.setter + def callback(self, func: CallbackT) -> None: + if not asyncio.iscoroutinefunction(func): + raise TypeError("Callback must be a coroutine.") + + self._callback = func + unwrap = unwrap_function(func) + self.module = unwrap.__module__ + + @property + def cooldown(self): + return self._buckets._cooldown + + @property + def full_parent_name(self) -> Optional[str]: + """:class:`str`: Retrieves the fully qualified parent command name. + + This the base command name required to execute it. For example, + in ``/one two three`` the parent name would be ``one two``. + """ + if self.parent: + return self.parent.qualified_name + + @property + def qualified_name(self) -> str: + """:class:`str`: Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``?one two three`` the qualified name would be + ``one two three``. + """ + if not self.parent: + return self.name + + return f"{self.parent.qualified_name} {self.name}" + + def __str__(self) -> str: + return self.qualified_name + + async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): + """|coro| + + Calls the internal callback that the command holds. + + .. note:: + + This bypasses all mechanisms -- including checks, converters, + invoke hooks, cooldowns, etc. You must take care to pass + the proper arguments and types to this function. + + """ + if self.cog is not None: + return await self.callback(self.cog, ctx, *args, **kwargs) + return await self.callback(ctx, *args, **kwargs) + + def error(self, coro: ErrorT) -> ErrorT: + """A decorator that registers a coroutine as a local error handler. + + A local error handler is an :func:`.on_command_error` event limited to + a single command. However, the :func:`.on_command_error` is still + invoked afterwards as the catch-all. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The error handler must be a coroutine.") + + self.on_error = coro + return coro + + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the command has an error handler registered.""" + return hasattr(self, "on_error") + + def before_invoke(self, coro: HookT) -> HookT: + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.before_invoke` for more info. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the pre-invoke hook. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro: HookT) -> HookT: + """A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.after_invoke` for more info. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the post-invoke hook. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + async def can_run(self, ctx: ContextT) -> bool: + """|coro| + + Checks if the command can be executed by checking all the predicates + inside the :attr:`~Command.checks` attribute. This also checks whether the + command is disabled. + + .. versionchanged:: 1.3 + Checks whether the command is disabled or not + + Parameters + ----------- + ctx: :class:`.Context` + The ctx of the command currently being invoked. + + Raises + ------- + :class:`CommandError` + Any command error that was raised during a check call will be propagated + by this function. + + Returns + -------- + :class:`bool` + A boolean indicating if the command can be invoked. + """ + if not self.enabled: + raise DisabledCommand(f"{self.name} command is disabled") + + original = ctx.command + ctx.command = self + + try: + if not await ctx.bot.can_run(ctx): + raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") + + # I personally don't think parent checks should be + # run with the subcommand. It causes confusion, and + # nerfs control for a bit of reduced redundancy + # predicates = self.checks + # if self.parent is not None: + # # parent checks should be run first + # predicates = self.parent.checks + predicates + + if (cog := self.cog) and (local_check := cog._get_overridden_method(cog.cog_check)): + ret = await utils.maybe_coroutine(local_check, ctx) + if not ret: + return False + + predicates = self.checks + if not predicates: + # since we have no checks, then we just return True. + return True + + return await utils.async_all(predicate(ctx) for predicate in predicates) + finally: + ctx.command = original + + # depends on what to do with the application_command_error event + + # async def dispatch_error(self, ctx: ContextT, error: Exception) -> None: + # ctx.command_failed = True + # cog = self.cog + + # if coro := getattr(self, "on_error", None): + # injected = wrap_callback(coro) + # if cog is not None: + # await injected(cog, ctx, error) + # else: + # await injected(ctx, error) + + # try: + # if cog is not None: + # local = cog.__class__._get_overridden_method(cog.cog_command_error) + # if local is not None: + # wrapped = wrap_callback(local) + # await wrapped(ctx, error) + # finally: + # ctx.bot.dispatch("application_command_error", ctx, error) + + def add_check(self, func: Check) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`.check`. + + Parameters + ----------- + func: Callable + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + Parameters + ----------- + func: Callable + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + + def _prepare_cooldowns(self, ctx: ContextT): + if not self._buckets.valid: + return + + current = datetime.datetime.now().timestamp() + bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message + + if bucket: + retry_after = bucket.update_rate_limit(current) + + if retry_after: + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + + def is_on_cooldown(self, ctx: ContextT) -> bool: + """Checks whether the command is currently on cooldown. + + .. note:: + + This uses the current time instead of the interaction time. + + Parameters + ----------- + ctx: :class:`.ApplicationContext` + The invocation context to use when checking the command's cooldown status. + + Returns + -------- + :class:`bool` + A boolean indicating if the command is on cooldown. + """ + if not self._buckets.valid: + return False + + bucket = self._buckets.get_bucket(ctx) + current = utils.utcnow().timestamp() + return bucket.get_tokens(current) == 0 + + def reset_cooldown(self, ctx) -> None: + """Resets the cooldown on this command. + + Parameters + ----------- + ctx: :class:`.ApplicationContext` + The invocation context to reset the cooldown under. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message + bucket.reset() + + def get_cooldown_retry_after(self, ctx) -> float: + """Retrieves the amount of seconds before this command can be tried again. + + .. note:: + + This uses the current time instead of the interaction time. + + Parameters + ----------- + ctx: :class:`.ApplicationContext` + The invocation context to retrieve the cooldown from. + + Returns + -------- + :class:`float` + The amount of time left on this command's cooldown in seconds. + If this is ``0.0`` then the command isn't on cooldown. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) + current = utils.utcnow().timestamp() + return bucket.get_retry_after(current) + + return 0.0 + + async def call_before_hooks(self, ctx: ContextT) -> None: + # now that we're done preparing we can call the pre-command hooks + # first, call the command local hook: + cog = self.cog + if self._before_invoke is not None: + # should be cog if @commands.before_invoke is used + instance = getattr(self._before_invoke, "__self__", cog) + # __self__ only exists for methods, not functions + # however, if @command.before_invoke is used, it will be a function + if instance: + await self._before_invoke(instance, ctx) # type: ignore + else: + await self._before_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = cog.__class__._get_overridden_method(cog.cog_before_invoke) + if hook is not None: + await hook(ctx) + + # call the bot global hook if necessary + hook = ctx.bot._before_invoke + if hook is not None: + await hook(ctx) + + async def call_after_hooks(self, ctx: ContextT) -> None: + cog = self.cog + if self._after_invoke is not None: + instance = getattr(self._after_invoke, "__self__", cog) + if instance: + await self._after_invoke(instance, ctx) # type: ignore + else: + await self._after_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = Cog._get_overridden_method(cog.cog_after_invoke) + if hook is not None: + await hook(ctx) + + hook = ctx.bot._after_invoke + if hook is not None: + await hook(ctx) + + async def _parse_arguments(self, ctx: ContextT) -> None: + return + + async def prepare(self, ctx: ContextT) -> None: + ctx.command = self + + if not await self.can_run(ctx): + raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") + + if self._max_concurrency is not None: + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) # type: ignore + + try: + self._prepare_cooldowns(ctx) + await self._parse_arguments(ctx) + + await self.call_before_hooks(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) # type: ignore + raise + + async def invoke(self, ctx: ContextT) -> None: + await self.prepare(ctx) + + # terminate the invoked_subcommand chain. + # since we're in a regular command (and not a group) then + # the invoked subcommand is None. + # ctx.invoked_subcommand = None + # ctx.subcommand_passed = None + injected = hooked_wrapped_callback(self, ctx, self.callback) + await injected(*ctx.args, **ctx.kwargs) + + def copy(self): + """Creates a copy of this command. + + Returns + -------- + :class:`Invokable` + A new instance of this command. + """ + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) + + def _ensure_assignment_on_copy(self, other): + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + if self._buckets.valid and not other._buckets.valid: + other._buckets = self._buckets.copy() + if self._max_concurrency != other._max_concurrency: + # _max_concurrency won't be None at this point + other._max_concurrency = self._max_concurrency.copy() # type: ignore + + try: + other.on_error = self.on_error + except AttributeError: + pass + return other + + def _update_copy(self, kwargs: Dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() diff --git a/discord/errors.py b/discord/errors.py index c3263b5567..4c5479bdfb 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -38,6 +38,7 @@ _ResponseType = ClientResponse from .interactions import Interaction + from .commands.cooldowns import BucketType, Cooldown __all__ = ( "DiscordException", @@ -302,6 +303,65 @@ def __init__(self, interaction: Interaction): super().__init__("This interaction has already been responded to before") +# command errors + +class CommandError(DiscordException): + r"""The base exception type for all command related errors. + + This inherits from :exc:`discord.DiscordException`. + + This exception and exceptions inherited from it are handled + in a special way as they are caught and passed into a special event + from :class:`.Bot`\, :func:`.on_command_error`. + """ + + def __init__(self, message: Optional[str] = None, *args: Any) -> None: + if message is not None: + # clean-up @everyone and @here mentions + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") + super().__init__(m, *args) + else: + super().__init__(*args) + + +class ApplicationCommandError(CommandError): + r"""The base exception type for all application command related errors. + + This inherits from :exc:`DiscordException`. + + This exception and exceptions inherited from it are handled + in a special way as they are caught and passed into a special event + from :class:`.Bot`\, :func:`.on_command_error`. + """ + pass + + +class CheckFailure(CommandError): + """Exception raised when the predicates in :attr:`.Command.checks` have failed. + + This inherits from :exc:`CommandError` + """ + + pass + + +class ApplicationCommandInvokeError(ApplicationCommandError): + """Exception raised when the command being invoked raised an exception. + + This inherits from :exc:`ApplicationCommandError` + + Attributes + ----------- + original: :exc:`Exception` + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. + """ + + def __init__(self, e: Exception) -> None: + self.original: Exception = e + super().__init__(f"Application Command raised an exception: {e.__class__.__name__}: {e}") + + class ExtensionError(DiscordException): """Base exception for extension related errors. @@ -390,31 +450,56 @@ def __init__(self, name: str) -> None: super().__init__(msg, name=name) -class ApplicationCommandError(DiscordException): - r"""The base exception type for all application command related errors. +class MaxConcurrencyReached(CommandError): + """Exception raised when the command being invoked has reached its maximum concurrency. - This inherits from :exc:`DiscordException`. + This inherits from :exc:`CommandError`. - This exception and exceptions inherited from it are handled - in a special way as they are caught and passed into a special event - from :class:`.Bot`\, :func:`.on_command_error`. + Attributes + ------------ + number: :class:`int` + The maximum number of concurrent invokers allowed. + per: :class:`.BucketType` + The bucket type passed to the :func:`.max_concurrency` decorator. """ - pass + def __init__(self, number: int, per: BucketType) -> None: + self.number: int = number + self.per: BucketType = per + name = per.name + suffix = f"per {name}" if per.name != "default" else "globally" + plural = "%s times %s" if number > 1 else "%s time %s" + fmt = plural % (number, suffix) + super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") -class CheckFailure(ApplicationCommandError): - """Exception raised when the predicates in :attr:`.Command.checks` have failed. - This inherits from :exc:`ApplicationCommandError` +class CommandOnCooldown(CommandError): + """Exception raised when the command being invoked is on cooldown. + + This inherits from :exc:`CommandError` + + Attributes + ----------- + cooldown: :class:`.Cooldown` + A class with attributes ``rate`` and ``per`` similar to the + :func:`.cooldown` decorator. + type: :class:`BucketType` + The type associated with the cooldown. + retry_after: :class:`float` + The amount of seconds to wait before you can retry again. """ - pass + def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: + self.cooldown: Cooldown = cooldown + self.retry_after: float = retry_after + self.type: BucketType = type + super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s") -class ApplicationCommandInvokeError(ApplicationCommandError): +class CommandInvokeError(CommandError): """Exception raised when the command being invoked raised an exception. - This inherits from :exc:`ApplicationCommandError` + This inherits from :exc:`CommandError` Attributes ----------- @@ -425,4 +510,13 @@ class ApplicationCommandInvokeError(ApplicationCommandError): def __init__(self, e: Exception) -> None: self.original: Exception = e - super().__init__(f"Application Command raised an exception: {e.__class__.__name__}: {e}") + super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") + + +class DisabledCommand(CommandError): + """Exception raised when the command being invoked is disabled. + + This inherits from :exc:`CommandError` + """ + + pass diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index c22121c63e..8d16af0014 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -23,20 +23,7 @@ DEALINGS IN THE SOFTWARE. """ -from __future__ import annotations - -import asyncio -import time -from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Optional, Type, TypeVar - -from discord.enums import Enum - -from ...abc import PrivateChannel -from .errors import MaxConcurrencyReached - -if TYPE_CHECKING: - from ...message import Message +from ...commands.cooldowns import * __all__ = ( "BucketType", @@ -45,347 +32,3 @@ "DynamicCooldownMapping", "MaxConcurrency", ) - -C = TypeVar("C", bound="CooldownMapping") -MC = TypeVar("MC", bound="MaxConcurrency") - - -class BucketType(Enum): - default = 0 - user = 1 - guild = 2 - channel = 3 - member = 4 - category = 5 - role = 6 - - def get_key(self, msg: Message) -> Any: - if self is BucketType.user: - return msg.author.id - elif self is BucketType.guild: - return (msg.guild or msg.author).id - elif self is BucketType.channel: - return msg.channel.id - elif self is BucketType.member: - return (msg.guild and msg.guild.id), msg.author.id - elif self is BucketType.category: - return (msg.channel.category or msg.channel).id # type: ignore - elif self is BucketType.role: - # we return the channel id of a private-channel as there are only roles in guilds - # and that yields the same result as for a guild with only the @everyone role - # NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are - # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore - - def __call__(self, msg: Message) -> Any: - return self.get_key(msg) - - -class Cooldown: - """Represents a cooldown for a command. - - Attributes - ----------- - rate: :class:`int` - The total number of tokens available per :attr:`per` seconds. - per: :class:`float` - The length of the cooldown period in seconds. - """ - - __slots__ = ("rate", "per", "_window", "_tokens", "_last") - - def __init__(self, rate: float, per: float) -> None: - self.rate: int = int(rate) - self.per: float = float(per) - self._window: float = 0.0 - self._tokens: int = self.rate - self._last: float = 0.0 - - def get_tokens(self, current: Optional[float] = None) -> int: - """Returns the number of available tokens before rate limiting is applied. - - Parameters - ------------ - current: Optional[:class:`float`] - The time in seconds since Unix epoch to calculate tokens at. - If not supplied then :func:`time.time()` is used. - - Returns - -------- - :class:`int` - The number of tokens available before the cooldown is to be applied. - """ - if not current: - current = time.time() - - tokens = self._tokens - - if current > self._window + self.per: - tokens = self.rate - return tokens - - def get_retry_after(self, current: Optional[float] = None) -> float: - """Returns the time in seconds until the cooldown will be reset. - - Parameters - ------------- - current: Optional[:class:`float`] - The current time in seconds since Unix epoch. - If not supplied, then :func:`time.time()` is used. - - Returns - ------- - :class:`float` - The number of seconds to wait before this cooldown will be reset. - """ - current = current or time.time() - tokens = self.get_tokens(current) - - if tokens == 0: - return self.per - (current - self._window) - - return 0.0 - - def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: - """Updates the cooldown rate limit. - - Parameters - ------------- - current: Optional[:class:`float`] - The time in seconds since Unix epoch to update the rate limit at. - If not supplied, then :func:`time.time()` is used. - - Returns - ------- - Optional[:class:`float`] - The retry-after time in seconds if rate limited. - """ - current = current or time.time() - self._last = current - - self._tokens = self.get_tokens(current) - - # first token used means that we start a new rate limit window - if self._tokens == self.rate: - self._window = current - - # check if we are rate limited - if self._tokens == 0: - return self.per - (current - self._window) - - # we're not so decrement our tokens - self._tokens -= 1 - - def reset(self) -> None: - """Reset the cooldown to its initial state.""" - self._tokens = self.rate - self._last = 0.0 - - def copy(self) -> Cooldown: - """Creates a copy of this cooldown. - - Returns - -------- - :class:`Cooldown` - A new instance of this cooldown. - """ - return Cooldown(self.rate, self.per) - - def __repr__(self) -> str: - return f"" - - -class CooldownMapping: - def __init__( - self, - original: Optional[Cooldown], - type: Callable[[Message], Any], - ) -> None: - if not callable(type): - raise TypeError("Cooldown type must be a BucketType or callable") - - self._cache: Dict[Any, Cooldown] = {} - self._cooldown: Optional[Cooldown] = original - self._type: Callable[[Message], Any] = type - - def copy(self) -> CooldownMapping: - ret = CooldownMapping(self._cooldown, self._type) - ret._cache = self._cache.copy() - return ret - - @property - def valid(self) -> bool: - return self._cooldown is not None - - @property - def type(self) -> Callable[[Message], Any]: - return self._type - - @classmethod - def from_cooldown(cls: Type[C], rate, per, type) -> C: - return cls(Cooldown(rate, per), type) - - def _bucket_key(self, msg: Message) -> Any: - return self._type(msg) - - def _verify_cache_integrity(self, current: Optional[float] = None) -> None: - # we want to delete all cache objects that haven't been used - # in a cooldown window. e.g. if we have a command that has a - # cooldown of 60s, and it has not been used in 60s then that key should be deleted - current = current or time.time() - dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per] - for k in dead_keys: - del self._cache[k] - - def create_bucket(self, message: Message) -> Cooldown: - return self._cooldown.copy() # type: ignore - - def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown: - if self._type is BucketType.default: - return self._cooldown # type: ignore - - self._verify_cache_integrity(current) - key = self._bucket_key(message) - if key not in self._cache: - bucket = self.create_bucket(message) - if bucket is not None: - self._cache[key] = bucket - else: - bucket = self._cache[key] - - return bucket - - def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: - bucket = self.get_bucket(message, current) - return bucket.update_rate_limit(current) - - -class DynamicCooldownMapping(CooldownMapping): - def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None: - super().__init__(None, type) - self._factory: Callable[[Message], Cooldown] = factory - - def copy(self) -> DynamicCooldownMapping: - ret = DynamicCooldownMapping(self._factory, self._type) - ret._cache = self._cache.copy() - return ret - - @property - def valid(self) -> bool: - return True - - def create_bucket(self, message: Message) -> Cooldown: - return self._factory(message) - - -class _Semaphore: - """This class is a version of a semaphore. - - If you're wondering why asyncio.Semaphore isn't being used, - it's because it doesn't expose the internal value. This internal - value is necessary because I need to support both `wait=True` and - `wait=False`. - - An asyncio.Queue could have been used to do this as well -- but it is - not as inefficient since internally that uses two queues and is a bit - overkill for what is basically a counter. - """ - - __slots__ = ("value", "loop", "_waiters") - - def __init__(self, number: int) -> None: - self.value: int = number - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - self._waiters: Deque[asyncio.Future] = deque() - - def __repr__(self) -> str: - return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>" - - def locked(self) -> bool: - return self.value == 0 - - def is_active(self) -> bool: - return len(self._waiters) > 0 - - def wake_up(self) -> None: - while self._waiters: - future = self._waiters.popleft() - if not future.done(): - future.set_result(None) - return - - async def acquire(self, *, wait: bool = False) -> bool: - if not wait and self.value <= 0: - # signal that we're not acquiring - return False - - while self.value <= 0: - future = self.loop.create_future() - self._waiters.append(future) - try: - await future - except: - future.cancel() - if self.value > 0 and not future.cancelled(): - self.wake_up() - raise - - self.value -= 1 - return True - - def release(self) -> None: - self.value += 1 - self.wake_up() - - -class MaxConcurrency: - __slots__ = ("number", "per", "wait", "_mapping") - - def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: - self._mapping: Dict[Any, _Semaphore] = {} - self.per: BucketType = per - self.number: int = number - self.wait: bool = wait - - if number <= 0: - raise ValueError("max_concurrency 'number' cannot be less than 1") - - if not isinstance(per, BucketType): - raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") - - def copy(self: MC) -> MC: - return self.__class__(self.number, per=self.per, wait=self.wait) - - def __repr__(self) -> str: - return f"" - - def get_key(self, message: Message) -> Any: - return self.per.get_key(message) - - async def acquire(self, message: Message) -> None: - key = self.get_key(message) - - try: - sem = self._mapping[key] - except KeyError: - self._mapping[key] = sem = _Semaphore(self.number) - - acquired = await sem.acquire(wait=self.wait) - if not acquired: - raise MaxConcurrencyReached(self.number, self.per) - - async def release(self, message: Message) -> None: - # Technically there's no reason for this function to be async - # But it might be more useful in the future - key = self.get_key(message) - - try: - sem = self._mapping[key] - except KeyError: - # ...? peculiar - return - else: - sem.release() - - if sem.value >= self.number and not sem.is_active(): - del self._mapping[key] diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 1b463a460f..5b29871726 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -27,7 +27,15 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union -from discord.errors import ClientException, DiscordException +from discord.errors import ( + ClientException, + CommandError, + CheckFailure, + CommandOnCooldown, + CommandInvokeError, + MaxConcurrencyReached, + DisabledCommand +) if TYPE_CHECKING: from inspect import Parameter @@ -38,7 +46,6 @@ from .context import Context from .converter import Converter - from .cooldowns import BucketType, Cooldown from .flags import Flag @@ -97,23 +104,7 @@ ) -class CommandError(DiscordException): - r"""The base exception type for all command related errors. - - This inherits from :exc:`discord.DiscordException`. - - This exception and exceptions inherited from it are handled - in a special way as they are caught and passed into a special event - from :class:`.Bot`\, :func:`.on_command_error`. - """ - - def __init__(self, message: Optional[str] = None, *args: Any) -> None: - if message is not None: - # clean-up @everyone and @here mentions - m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") - super().__init__(m, *args) - else: - super().__init__(*args) +from ...errors import CommandError class ConversionError(CommandError): @@ -195,15 +186,6 @@ class BadArgument(UserInputError): pass -class CheckFailure(CommandError): - """Exception raised when the predicates in :attr:`.Command.checks` have failed. - - This inherits from :exc:`CommandError` - """ - - pass - - class CheckAnyFailure(CheckFailure): """Exception raised when all predicates in :func:`check_any` fail. @@ -529,78 +511,6 @@ def __init__(self, argument: str) -> None: super().__init__(f"{argument} is not a recognised boolean option") -class DisabledCommand(CommandError): - """Exception raised when the command being invoked is disabled. - - This inherits from :exc:`CommandError` - """ - - pass - - -class CommandInvokeError(CommandError): - """Exception raised when the command being invoked raised an exception. - - This inherits from :exc:`CommandError` - - Attributes - ----------- - original: :exc:`Exception` - The original exception that was raised. You can also get this via - the ``__cause__`` attribute. - """ - - def __init__(self, e: Exception) -> None: - self.original: Exception = e - super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") - - -class CommandOnCooldown(CommandError): - """Exception raised when the command being invoked is on cooldown. - - This inherits from :exc:`CommandError` - - Attributes - ----------- - cooldown: :class:`.Cooldown` - A class with attributes ``rate`` and ``per`` similar to the - :func:`.cooldown` decorator. - type: :class:`BucketType` - The type associated with the cooldown. - retry_after: :class:`float` - The amount of seconds to wait before you can retry again. - """ - - def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: - self.cooldown: Cooldown = cooldown - self.retry_after: float = retry_after - self.type: BucketType = type - super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s") - - -class MaxConcurrencyReached(CommandError): - """Exception raised when the command being invoked has reached its maximum concurrency. - - This inherits from :exc:`CommandError`. - - Attributes - ------------ - number: :class:`int` - The maximum number of concurrent invokers allowed. - per: :class:`.BucketType` - The bucket type passed to the :func:`.max_concurrency` decorator. - """ - - def __init__(self, number: int, per: BucketType) -> None: - self.number: int = number - self.per: BucketType = per - name = per.name - suffix = f"per {name}" if per.name != "default" else "globally" - plural = "%s times %s" if number > 1 else "%s time %s" - fmt = plural % (number, suffix) - super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") - - class MissingRole(CheckFailure): """Exception raised when the command invoker lacks a role to run a command. From 2fcd736946a3d6c8c1492d511b850fcfbda600aa Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 17:08:47 -0400 Subject: [PATCH 05/54] Spelling is close enough --- .github/workflows/codespell.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml index 6321a51781..209f32fa41 100644 --- a/.github/workflows/codespell.yml +++ b/.github/workflows/codespell.yml @@ -8,4 +8,4 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - run: pip install codespell==2.1.0 - - run: codespell --ignore-words-list="groupt,nd,ot,ro,falsy,BU" --exclude-file=".github/workflows/codespell.yml" + - run: codespell --ignore-words-list="groupt,nd,ot,ro,falsy,BU,invokable" --exclude-file=".github/workflows/codespell.yml" From c2a41e4f206ef89b4ce9ef88fa4ea3abcfd69ecd Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 17:32:46 -0400 Subject: [PATCH 06/54] Integrate ext.commands + typehinting fixes --- discord/commands/__init__.py | 1 + discord/commands/core.py | 10 +- discord/commands/invokable.py | 32 +- discord/ext/commands/core.py | 531 ++-------------------------------- 4 files changed, 48 insertions(+), 526 deletions(-) diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py index 1813faf3dc..3b6c5b70b9 100644 --- a/discord/commands/__init__.py +++ b/discord/commands/__init__.py @@ -27,3 +27,4 @@ from .core import * from .options import * from .permissions import * +from .invokable import Invokable, _BaseCommand diff --git a/discord/commands/core.py b/discord/commands/core.py index 2bf3d84978..b8deb7a2e2 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -26,7 +26,6 @@ from __future__ import annotations import asyncio -import datetime import functools import inspect import re @@ -53,7 +52,6 @@ from ..errors import ( ApplicationCommandError, ApplicationCommandInvokeError, - CheckFailure, ClientException, ValidationError, ) @@ -63,8 +61,8 @@ from ..role import Role from ..threads import Thread from ..user import User -from .invokable import Invokable -from ..utils import async_all, find, maybe_coroutine, MISSING +from .invokable import Invokable, _BaseCommand +from ..utils import find, MISSING from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice @@ -168,10 +166,6 @@ def _validate_descriptions(obj): validate_chat_input_description(string, locale=locale) -class _BaseCommand: - __slots__ = () - - class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): __original_kwargs__: Dict[str, Any] cog = None diff --git a/discord/commands/invokable.py b/discord/commands/invokable.py index 2252efe56d..30a39f4576 100644 --- a/discord/commands/invokable.py +++ b/discord/commands/invokable.py @@ -34,7 +34,6 @@ CallbackT = TypeVar("CallbackT") ErrorT = TypeVar("ErrorT") - HookT = TypeVar("HookT") ContextT = TypeVar("ContextT") T = TypeVar("T") @@ -42,8 +41,12 @@ MaybeCoro = Union[T, Coro[T]] Check = Union[ - Callable[[ContextT], MaybeCoro[bool]], - Callable[[Cog, ContextT], MaybeCoro[bool]] + Callable[[Cog, ContextT], MaybeCoro[bool]], # TODO: replace with stardized context superclass + Callable[[ContextT], MaybeCoro[bool]], # as well as for the others + ] + Hook = Union[ + Callable[[Cog, ContextT], Coro[Any]], + Callable[[ContextT], Coro[Any]] ] @@ -98,13 +101,12 @@ async def wrapped(*args, **kwargs): return wrapped +class _BaseCommand: + __slots__ = () + + class Invokable: - checks: List[Check] - _buckets: CooldownMapping - _max_concurrency: Optional[MaxConcurrency] on_error: Optional[ErrorT] - _before_invoke: Optional[HookT] - _after_invoke: Optional[HookT] def __init__(self, func: CallbackT, **kwargs): self.module: Any = None @@ -120,7 +122,7 @@ def __init__(self, func: CallbackT, **kwargs): checks.reverse() checks += kwargs.get("checks", []) # combine all the checks we find (kwargs or decorator) - self.checks = checks + self.checks: List[Check] = checks # cooldowns cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) @@ -132,17 +134,17 @@ def __init__(self, func: CallbackT, **kwargs): else: raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") - self._buckets = buckets + self._buckets: CooldownMapping = buckets # max concurrency - self._max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) + self._max_concurrency: Optional[MaxConcurrency] = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) # hooks - self._before_invoke = None + self._before_invoke: Optional[Hook] = None if hook := getattr(func, "__before_invoke__", None): self.before_invoke(hook) - self._after_invoke = None + self._after_invoke: Optional[Hook] = None if hook := getattr(func, "__after_invoke__", None): self.after_invoke(hook) @@ -232,7 +234,7 @@ def has_error_handler(self) -> bool: """:class:`bool`: Checks whether the command has an error handler registered.""" return hasattr(self, "on_error") - def before_invoke(self, coro: HookT) -> HookT: + def before_invoke(self, coro: Hook) -> Hook: """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is @@ -259,7 +261,7 @@ def before_invoke(self, coro: HookT) -> HookT: self._before_invoke = coro return coro - def after_invoke(self, coro: HookT) -> HookT: + def after_invoke(self, coro: Hook) -> Hook: """A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 053aae3d38..b3ff407907 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -57,6 +57,7 @@ slash_command, user_command, ) +from ...commands.invokable import Invokable from ...errors import * from .cog import Cog from .context import Context @@ -126,6 +127,17 @@ else: P = TypeVar("P") +CallbackT = Union[ + Callable[ + [Concatenate[CogT, ContextT, P]], + Coro[T] + ], + Callable[ + [Concatenate[ContextT, P]], + Coro[T] + ], +] + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: partial = functools.partial @@ -221,7 +233,7 @@ def __setitem__(self, k, v): super().__setitem__(k.casefold(), v) -class Command(_BaseCommand, Generic[CogT, P, T]): +class Command(Invokable, _BaseCommand, Generic[CogT, P, T]): r"""A class that implements the protocol for a bot text command. These are not created manually, instead they are created via the @@ -303,6 +315,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): .. versionadded:: 2.0 """ __original_kwargs__: Dict[str, Any] + _callback: CallbackT def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT: # if you're wondering why this is done, it's because we need to ensure @@ -322,28 +335,10 @@ def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT: def __init__( self, - func: Union[ - Callable[ - [Concatenate[CogT, ContextT, P]], - Coro[T] - ], - Callable[ - [Concatenate[ContextT, P]], - Coro[T] - ], - ], + func: CallbackT, **kwargs: Any, ): - if not asyncio.iscoroutinefunction(func): - raise TypeError("Callback must be a coroutine.") - - name = kwargs.get("name") or func.__name__ - if not isinstance(name, str): - raise TypeError("Name of a command must be a string.") - self.name: str = name - - self.callback = func - self.enabled: bool = kwargs.get("enabled", True) + super().__init__(func, **kwargs) help_doc = kwargs.get("help") if help_doc is not None: @@ -367,90 +362,30 @@ def __init__( self.description: str = inspect.cleandoc(kwargs.get("description", "")) self.hidden: bool = kwargs.get("hidden", False) - try: - checks = func.__commands_checks__ - checks.reverse() - except AttributeError: - checks = kwargs.get("checks", []) - - self.checks: List[Check] = checks - - try: - cooldown = func.__commands_cooldown__ - except AttributeError: - cooldown = kwargs.get("cooldown") - - if cooldown is None: - buckets = CooldownMapping(cooldown, BucketType.default) - elif isinstance(cooldown, CooldownMapping): - buckets = cooldown - else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") - self._buckets: CooldownMapping = buckets - - try: - max_concurrency = func.__commands_max_concurrency__ - except AttributeError: - max_concurrency = kwargs.get("max_concurrency") - - self._max_concurrency: Optional[MaxConcurrency] = max_concurrency - self.require_var_positional: bool = kwargs.get("require_var_positional", False) self.ignore_extra: bool = kwargs.get("ignore_extra", True) - self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) - self.cog: Optional[CogT] = None + # TODO: maybe??? + # self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) + # TODO: typing + # self.cog: Optional[CogT] = None # bandaid for the fact that sometimes parent can be the bot instance parent = kwargs.get("parent") self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore - self._before_invoke: Optional[Hook] = None - try: - before_invoke = func.__before_invoke__ - except AttributeError: - pass - else: - self.before_invoke(before_invoke) - - self._after_invoke: Optional[Hook] = None - try: - after_invoke = func.__after_invoke__ - except AttributeError: - pass - else: - self.after_invoke(after_invoke) - @property def callback( self, - ) -> Union[ - Callable[ - [Concatenate[CogT, Context, P]], - Coro[T] - ], - Callable[ - [Concatenate[Context, P]], - Coro[T] - ], - ]: + ) -> CallbackT: return self._callback @callback.setter def callback( self, - function: Union[ - Callable[ - [Concatenate[CogT, Context, P]], - Coro[T] - ], - Callable[ - [Concatenate[Context, P]], - Coro[T] - ], - ], + func: CallbackT ) -> None: - self._callback = function - unwrap = unwrap_function(function) + self._callback = func + unwrap = unwrap_function(func) self.module = unwrap.__module__ try: @@ -458,41 +393,7 @@ def callback( except AttributeError: globalns = {} - self.params = get_signature_parameters(function, globalns) - - def add_check(self, func: Check) -> None: - """Adds a check to the command. - - This is the non-decorator interface to :func:`.check`. - - .. versionadded:: 1.3 - - Parameters - ----------- - func - The function that will be used as a check. - """ - - self.checks.append(func) - - def remove_check(self, func: Check) -> None: - """Removes a check from the command. - - This function is idempotent and will not raise an exception - if the function is not in the command's checks. - - .. versionadded:: 1.3 - - Parameters - ----------- - func - The function to remove from the checks. - """ - - try: - self.checks.remove(func) - except ValueError: - pass + self.params = get_signature_parameters(func, globalns) def update(self, **kwargs: Any) -> None: """Updates :class:`Command` instance with updated attribute. @@ -503,61 +404,6 @@ def update(self, **kwargs: Any) -> None: """ self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) - async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T: - """|coro| - - Calls the internal callback that the command holds. - - .. note:: - - This bypasses all mechanisms -- including checks, converters, - invoke hooks, cooldowns, etc. You must take care to pass - the proper arguments and types to this function. - - .. versionadded:: 1.3 - """ - if self.cog is not None: - return await self.callback(self.cog, context, *args, **kwargs) # type: ignore - else: - return await self.callback(context, *args, **kwargs) # type: ignore - - def _ensure_assignment_on_copy(self, other: CommandT) -> CommandT: - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - if self._buckets.valid and not other._buckets.valid: - other._buckets = self._buckets.copy() - if self._max_concurrency != other._max_concurrency: - # _max_concurrency won't be None at this point - other._max_concurrency = self._max_concurrency.copy() # type: ignore - - try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def copy(self: CommandT) -> CommandT: - """Creates a copy of this command. - - Returns - -------- - :class:`Command` - A new instance of this command. - """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) - - def _update_copy(self: CommandT, kwargs: Dict[str, Any]) -> CommandT: - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - async def dispatch_error(self, ctx: Context, error: Exception) -> None: ctx.command_failed = True cog = self.cog @@ -686,22 +532,6 @@ def clean_params(self) -> Dict[str, inspect.Parameter]: return result - @property - def full_parent_name(self) -> str: - """:class:`str`: Retrieves the fully qualified parent command name. - - This the base command name required to execute it. For example, - in ``?one two three`` the parent name would be ``one two``. - """ - entries = [] - command = self - # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO - while command.parent is not None: # type: ignore - command = command.parent # type: ignore - entries.append(command.name) # type: ignore - - return " ".join(reversed(entries)) - @property def parents(self) -> List[Group]: """List[:class:`Group`]: Retrieves the parents of this command. @@ -732,24 +562,6 @@ def root_parent(self) -> Optional[Group]: return None return self.parents[-1] - @property - def qualified_name(self) -> str: - """:class:`str`: Retrieves the fully qualified command name. - - This is the full parent name with the command name as well. - For example, in ``?one two three`` the qualified name would be - ``one two three``. - """ - - parent = self.full_parent_name - if parent: - return f"{parent} {self.name}" - else: - return self.name - - def __str__(self) -> str: - return self.qualified_name - async def _parse_arguments(self, ctx: Context) -> None: ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = {} @@ -800,156 +612,7 @@ async def _parse_arguments(self, ctx: Context) -> None: if not self.ignore_extra and not view.eof: raise TooManyArguments(f"Too many arguments passed to {self.qualified_name}") - async def call_before_hooks(self, ctx: Context) -> None: - # now that we're done preparing we can call the pre-command hooks - # first, call the command local hook: - cog = self.cog - if self._before_invoke is not None: - # should be cog if @commands.before_invoke is used - instance = getattr(self._before_invoke, "__self__", cog) - # __self__ only exists for methods, not functions - # however, if @command.before_invoke is used, it will be a function - if instance: - await self._before_invoke(instance, ctx) # type: ignore - else: - await self._before_invoke(ctx) # type: ignore - - # call the cog local hook if applicable: - if cog is not None: - hook = Cog._get_overridden_method(cog.cog_before_invoke) - if hook is not None: - await hook(ctx) - - # call the bot global hook if necessary - hook = ctx.bot._before_invoke - if hook is not None: - await hook(ctx) - - async def call_after_hooks(self, ctx: Context) -> None: - cog = self.cog - if self._after_invoke is not None: - instance = getattr(self._after_invoke, "__self__", cog) - if instance: - await self._after_invoke(instance, ctx) # type: ignore - else: - await self._after_invoke(ctx) # type: ignore - - # call the cog local hook if applicable: - if cog is not None: - hook = Cog._get_overridden_method(cog.cog_after_invoke) - if hook is not None: - await hook(ctx) - - hook = ctx.bot._after_invoke - if hook is not None: - await hook(ctx) - - def _prepare_cooldowns(self, ctx: Context) -> None: - if self._buckets.valid: - dt = ctx.message.edited_at or ctx.message.created_at - current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() - bucket = self._buckets.get_bucket(ctx.message, current) - if bucket is not None: - retry_after = bucket.update_rate_limit(current) - if retry_after: - raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - - async def prepare(self, ctx: Context) -> None: - ctx.command = self - - if not await self.can_run(ctx): - raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") - - if self._max_concurrency is not None: - # For this application, context can be duck-typed as a Message - await self._max_concurrency.acquire(ctx) # type: ignore - - try: - if self.cooldown_after_parsing: - await self._parse_arguments(ctx) - self._prepare_cooldowns(ctx) - else: - self._prepare_cooldowns(ctx) - await self._parse_arguments(ctx) - - await self.call_before_hooks(ctx) - except: - if self._max_concurrency is not None: - await self._max_concurrency.release(ctx) # type: ignore - raise - - @property - def cooldown(self) -> Optional[Cooldown]: - return self._buckets._cooldown - - def is_on_cooldown(self, ctx: Context) -> bool: - """Checks whether the command is currently on cooldown. - - Parameters - ----------- - ctx: :class:`.Context` - The invocation context to use when checking the command's cooldown status. - - Returns - -------- - :class:`bool` - A boolean indicating if the command is on cooldown. - """ - if not self._buckets.valid: - return False - - bucket = self._buckets.get_bucket(ctx.message) - dt = ctx.message.edited_at or ctx.message.created_at - current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() - return bucket.get_tokens(current) == 0 - - def reset_cooldown(self, ctx: Context) -> None: - """Resets the cooldown on this command. - - Parameters - ----------- - ctx: :class:`.Context` - The invocation context to reset the cooldown under. - """ - if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx.message) - bucket.reset() - - def get_cooldown_retry_after(self, ctx: Context) -> float: - """Retrieves the amount of seconds before this command can be tried again. - - .. versionadded:: 1.4 - - Parameters - ----------- - ctx: :class:`.Context` - The invocation context to retrieve the cooldown from. - - Returns - -------- - :class:`float` - The amount of time left on this command's cooldown in seconds. - If this is ``0.0`` then the command isn't on cooldown. - """ - if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx.message) - dt = ctx.message.edited_at or ctx.message.created_at - current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() - return bucket.get_retry_after(current) - - return 0.0 - - async def invoke(self, ctx: Context) -> None: - await self.prepare(ctx) - - # terminate the invoked_subcommand chain. - # since we're in a regular command (and not a group) then - # the invoked subcommand is None. - ctx.invoked_subcommand = None - ctx.subcommand_passed = None - injected = hooked_wrapped_callback(self, ctx, self.callback) - await injected(*ctx.args, **ctx.kwargs) - + # TODO: async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: ctx.command = self await self._parse_arguments(ctx) @@ -967,91 +630,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: if call_hooks: await self.call_after_hooks(ctx) - def error(self, coro: ErrorT) -> ErrorT: - """A decorator that registers a coroutine as a local error handler. - - A local error handler is an :func:`.on_command_error` event limited to - a single command. However, the :func:`.on_command_error` is still - invoked afterwards as the catch-all. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the local error handler. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The error handler must be a coroutine.") - - self.on_error: Error = coro - return coro - - def has_error_handler(self) -> bool: - """:class:`bool`: Checks whether the command has an error handler registered. - - .. versionadded:: 1.7 - """ - return hasattr(self, "on_error") - - def before_invoke(self, coro: HookT) -> HookT: - """A decorator that registers a coroutine as a pre-invoke hook. - - A pre-invoke hook is called directly before the command is - called. This makes it a useful function to set up database - connections or any type of set up required. - - This pre-invoke hook takes a sole parameter, a :class:`.Context`. - - See :meth:`.Bot.before_invoke` for more info. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the pre-invoke hook. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The pre-invoke hook must be a coroutine.") - - self._before_invoke = coro - return coro - - def after_invoke(self, coro: HookT) -> HookT: - """A decorator that registers a coroutine as a post-invoke hook. - - A post-invoke hook is called directly after the command is - called. This makes it a useful function to clean-up database - connections or any type of clean up required. - - This post-invoke hook takes a sole parameter, a :class:`.Context`. - - See :meth:`.Bot.after_invoke` for more info. - - Parameters - ----------- - coro: :ref:`coroutine ` - The coroutine to register as the post-invoke hook. - - Raises - ------- - TypeError - The coroutine passed is not actually a coroutine. - """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError("The post-invoke hook must be a coroutine.") - - self._after_invoke = coro - return coro - + # TODO: @property def cog_name(self) -> Optional[str]: """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" @@ -1131,64 +710,10 @@ def signature(self) -> str: return " ".join(result) - async def can_run(self, ctx: Context) -> bool: - """|coro| - - Checks if the command can be executed by checking all the predicates - inside the :attr:`~Command.checks` attribute. This also checks whether the - command is disabled. - - .. versionchanged:: 1.3 - Checks whether the command is disabled or not - - Parameters - ----------- - ctx: :class:`.Context` - The ctx of the command currently being invoked. - - Raises - ------- - :class:`CommandError` - Any command error that was raised during a check call will be propagated - by this function. - - Returns - -------- - :class:`bool` - A boolean indicating if the command can be invoked. - """ - - if not self.enabled: - raise DisabledCommand(f"{self.name} command is disabled") - - original = ctx.command - ctx.command = self - - try: - if not await ctx.bot.can_run(ctx): - raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") - - cog = self.cog - if cog is not None: - local_check = Cog._get_overridden_method(cog.cog_check) - if local_check is not None: - ret = await discord.utils.maybe_coroutine(local_check, ctx) - if not ret: - return False - - predicates = self.checks - if not predicates: - # since we have no checks, then we just return True. - return True - - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore - finally: - ctx.command = original - def _set_cog(self, cog): self.cog = cog - +# TODO: This is a mess class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. From c15d97a34a1e75cd6fed080f125833cb8e1f2f6d Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 17:46:35 -0400 Subject: [PATCH 07/54] New baseline methods --- discord/commands/invokable.py | 14 +++++++++++++- discord/ext/commands/core.py | 15 --------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/discord/commands/invokable.py b/discord/commands/invokable.py index 30a39f4576..c49a6eb6df 100644 --- a/discord/commands/invokable.py +++ b/discord/commands/invokable.py @@ -111,7 +111,7 @@ class Invokable: def __init__(self, func: CallbackT, **kwargs): self.module: Any = None self.cog: Optional[Cog] - self.parent: Optional[Invokable] = kwargs.get("parent") + self.parent: Optional[Invokable] = (parent := kwargs.get("parent")) if isinstance(parent, _BaseCommand) else None self.callback: CallbackT = func self.name: str = str(kwargs.get("name", func.__name__)) @@ -188,6 +188,11 @@ def qualified_name(self) -> str: return f"{self.parent.qualified_name} {self.name}" + @property + def cog_name(self) -> Optional[str]: + """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" + return type(self.cog).__cog_name__ if self.cog is not None else None + def __str__(self) -> str: return self.qualified_name @@ -207,6 +212,13 @@ async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): return await self.callback(self.cog, ctx, *args, **kwargs) return await self.callback(ctx, *args, **kwargs) + def update(self, **kwargs: Any) -> None: + """Updates the :class:`Command` instance with updated attribute. + + Similar to creating a new instance except it updates the current. + """ + self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) + def error(self, coro: ErrorT) -> ErrorT: """A decorator that registers a coroutine as a local error handler. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index b3ff407907..1c6a6bd0a9 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -395,15 +395,6 @@ def callback( self.params = get_signature_parameters(func, globalns) - def update(self, **kwargs: Any) -> None: - """Updates :class:`Command` instance with updated attribute. - - This works similarly to the :func:`.command` decorator in terms - of parameters in that they are passed to the :class:`Command` or - subclass constructors, sans the name and callback. - """ - self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) - async def dispatch_error(self, ctx: Context, error: Exception) -> None: ctx.command_failed = True cog = self.cog @@ -630,12 +621,6 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: if call_hooks: await self.call_after_hooks(ctx) - # TODO: - @property - def cog_name(self) -> Optional[str]: - """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" - return type(self.cog).__cog_name__ if self.cog is not None else None - @property def short_doc(self) -> str: """:class:`str`: Gets the "short" documentation of a command. From 3ce64b72cda68659133f6a993e3100309be2566c Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 27 Aug 2022 19:10:25 -0400 Subject: [PATCH 08/54] Some cleanup before standard context impl --- discord/bot.py | 1 - discord/commands/invokable.py | 29 +++++++++++++++++++---------- discord/ext/commands/bot.py | 4 +--- discord/ext/commands/context.py | 1 - discord/ext/commands/core.py | 1 - discord/ext/commands/errors.py | 3 --- discord/scheduled_events.py | 3 +-- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index 2354806a15..8b23ea40a8 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -1069,7 +1069,6 @@ async def on_application_command_error(self, context: ApplicationContext, except traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) # global check registration - # TODO: Remove these from commands.Bot def check(self, func): """A decorator that adds a global check to the bot. A global check is similar to a :func:`.check` that is diff --git a/discord/commands/invokable.py b/discord/commands/invokable.py index c49a6eb6df..5d1ae25572 100644 --- a/discord/commands/invokable.py +++ b/discord/commands/invokable.py @@ -5,14 +5,15 @@ import functools from typing import ( TYPE_CHECKING, - Any, + Generic, + TypeVar, Callable, Coroutine, - Dict, - List, Optional, - TypeVar, Union, + Any, + Dict, + List, ) from .. import utils @@ -33,8 +34,7 @@ P = ParamSpec("P") CallbackT = TypeVar("CallbackT") - ErrorT = TypeVar("ErrorT") - ContextT = TypeVar("ContextT") + ContextT = TypeVar("ContextT", bound="Context") # TODO: yes T = TypeVar("T") Coro = Coroutine[Any, Any, T] @@ -44,10 +44,19 @@ Callable[[Cog, ContextT], MaybeCoro[bool]], # TODO: replace with stardized context superclass Callable[[ContextT], MaybeCoro[bool]], # as well as for the others ] + + Error = Union[ + Callable[[Cog, "Context[Any]", CommandError], Coro[Any]], + Callable[["Context[Any]", CommandError], Coro[Any]], + ] + ErrorT = TypeVar("ErrorT", bound="Error") + Hook = Union[ Callable[[Cog, ContextT], Coro[Any]], Callable[[ContextT], Coro[Any]] ] + HookT = TypeVar("HookT", bound="Hook") + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: @@ -106,8 +115,6 @@ class _BaseCommand: class Invokable: - on_error: Optional[ErrorT] - def __init__(self, func: CallbackT, **kwargs): self.module: Any = None self.cog: Optional[Cog] @@ -148,6 +155,8 @@ def __init__(self, func: CallbackT, **kwargs): if hook := getattr(func, "__after_invoke__", None): self.after_invoke(hook) + self.on_error: Optional[Error] + @property def callback(self) -> CallbackT: return self._callback @@ -246,7 +255,7 @@ def has_error_handler(self) -> bool: """:class:`bool`: Checks whether the command has an error handler registered.""" return hasattr(self, "on_error") - def before_invoke(self, coro: Hook) -> Hook: + def before_invoke(self, coro: HookT) -> HookT: """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is @@ -273,7 +282,7 @@ def before_invoke(self, coro: Hook) -> Hook: self._before_invoke = coro return coro - def after_invoke(self, coro: Hook) -> Hook: + def after_invoke(self, coro: HookT) -> HookT: """A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index c3af75c236..75b82f32b6 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -40,11 +40,9 @@ from .view import StringView if TYPE_CHECKING: - import importlib.machinery - from discord.message import Message - from ._types import Check, CoroFunc + from ._types import CoroFunc __all__ = ( "when_mentioned", diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 019e7b8bee..781c94cb72 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -45,7 +45,6 @@ from .bot import AutoShardedBot, Bot from .cog import Cog from .core import Command - from .help import HelpCommand from .view import StringView __all__ = ("Context",) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 1c6a6bd0a9..b6578c26fb 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -25,7 +25,6 @@ from __future__ import annotations import asyncio -import datetime import functools import inspect import types diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 5b29871726..13db7cac21 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -104,9 +104,6 @@ ) -from ...errors import CommandError - - class ConversionError(CommandError): """Exception raised when a Converter class raises non-CommandError. diff --git a/discord/scheduled_events.py b/discord/scheduled_events.py index b7614f2a48..1771cdd370 100644 --- a/discord/scheduled_events.py +++ b/discord/scheduled_events.py @@ -35,7 +35,6 @@ ScheduledEventStatus, try_enum, ) -from .errors import ValidationError from .iterators import ScheduledEventSubscribersIterator from .mixins import Hashable from .object import Object @@ -351,7 +350,7 @@ async def edit( if end_time is MISSING and location.type is ScheduledEventLocationType.external: end_time = self.end_time if end_time is None: - raise ValidationError("end_time needs to be passed if location type is external.") + raise TypeError("end_time needs to be passed if location type is external.") if start_time is not MISSING: payload["scheduled_start_time"] = start_time.isoformat() From 1ea034259eaed1232d6a49a47dcd1677b0b0b259 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 09:38:50 -0400 Subject: [PATCH 09/54] Remove some unused imports and vars --- discord/commands/invokable.py | 1 - discord/ext/commands/converter.py | 2 -- discord/webhook/sync.py | 2 -- ok.py | 18 ++++++++++++++++++ 4 files changed, 18 insertions(+), 5 deletions(-) create mode 100644 ok.py diff --git a/discord/commands/invokable.py b/discord/commands/invokable.py index 5d1ae25572..e94eaaf4e8 100644 --- a/discord/commands/invokable.py +++ b/discord/commands/invokable.py @@ -5,7 +5,6 @@ import functools from typing import ( TYPE_CHECKING, - Generic, TypeVar, Callable, Coroutine, diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 5ec807f417..fc5e795942 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -464,8 +464,6 @@ def check(c): @staticmethod def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: - bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index 459c174209..e5ff135d50 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -45,8 +45,6 @@ Literal, Optional, Tuple, - Type, - TypeVar, Union, overload, ) diff --git a/ok.py b/ok.py new file mode 100644 index 0000000000..c88045726d --- /dev/null +++ b/ok.py @@ -0,0 +1,18 @@ +class AA: + def __init__(self): + print("AA") + +class AB: + def __init__(self): + print("AB") + +class BA(AB, AA): + def __init__(self): + super(AB, self).__init__() + +test = BA() + +import time +time.sleep(2) + +print(BA) From 633b987441d117ddf50342f9cd604a7499ac603e Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 11:14:12 -0400 Subject: [PATCH 10/54] Implement BaseContext --- discord/commands/__init__.py | 2 +- discord/commands/core.py | 2 +- discord/commands/{invokable.py => mixins.py} | 130 ++++++++++++++++++- discord/ext/commands/context.py | 62 ++------- discord/ext/commands/core.py | 4 +- 5 files changed, 136 insertions(+), 64 deletions(-) rename discord/commands/{invokable.py => mixins.py} (81%) diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py index 3b6c5b70b9..cc4a8d25bf 100644 --- a/discord/commands/__init__.py +++ b/discord/commands/__init__.py @@ -27,4 +27,4 @@ from .core import * from .options import * from .permissions import * -from .invokable import Invokable, _BaseCommand +from .mixins import Invokable, _BaseCommand diff --git a/discord/commands/core.py b/discord/commands/core.py index b8deb7a2e2..657b312c6c 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -61,8 +61,8 @@ from ..role import Role from ..threads import Thread from ..user import User -from .invokable import Invokable, _BaseCommand from ..utils import find, MISSING +from .mixins import Invokable, _BaseCommand from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice diff --git a/discord/commands/invokable.py b/discord/commands/mixins.py similarity index 81% rename from discord/commands/invokable.py rename to discord/commands/mixins.py index e94eaaf4e8..a7754b0976 100644 --- a/discord/commands/invokable.py +++ b/discord/commands/mixins.py @@ -10,12 +10,13 @@ Coroutine, Optional, Union, + Generic, Any, Dict, List, ) -from .. import utils +from .. import utils, abc from ..errors import ( ApplicationCommandError, CheckFailure, @@ -28,12 +29,22 @@ if TYPE_CHECKING: from typing_extensions import ParamSpec + from ..bot import Bot, AutoShardedBot from ..cog import Cog + from ..user import User, ClientUser + from ..member import Member + from ..guild import Guild + from ..message import Message + from ..interactions import Interaction + from ..abc import MessageableChannel + from ..voice_client import VoiceProtocol P = ParamSpec("P") + BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") + CogT = TypeVar("CogT", bound="Cog") CallbackT = TypeVar("CallbackT") - ContextT = TypeVar("ContextT", bound="Context") # TODO: yes + ContextT = TypeVar("ContextT", bound="BaseContext") T = TypeVar("T") Coro = Coroutine[Any, Any, T] @@ -45,8 +56,8 @@ ] Error = Union[ - Callable[[Cog, "Context[Any]", CommandError], Coro[Any]], - Callable[["Context[Any]", CommandError], Coro[Any]], + Callable[[Cog, "BaseContext[Any]", CommandError], Coro[Any]], + Callable[["BaseContext[Any]", CommandError], Coro[Any]], ] ErrorT = TypeVar("ErrorT", bound="Error") @@ -113,7 +124,112 @@ class _BaseCommand: __slots__ = () -class Invokable: +class BaseContext(abc.Messageable, Generic[BotT]): + def __init__( + self, + *, + bot: Bot, + command: Optional[Invokable], + args: List[Any] = utils.MISSING, + kwargs: Dict[str, Any] = utils.MISSING + ): + self.bot: Bot = bot + self.command: Optional[Invokable] = command + self.args: List[Any] = args or [] + self.kwargs: Dict[str, Any] = kwargs or {} + + async def invoke(self, command: Invokable[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: + r"""|coro| + + Calls a command with the arguments given. + + This is useful if you want to just call the callback that a + :class:`.Invokable` holds internally. + + .. note:: + + This does not handle converters, checks, cooldowns, before-invoke, + or after-invoke hooks in any matter. It calls the internal callback + directly as-if it was a regular function. + + You must take care in passing the proper arguments when + using this function. + + Parameters + ----------- + command: :class:`.Invokable` + The command that is going to be called. + \*args + The arguments to use. + \*\*kwargs + The keyword arguments to use. + + Raises + ------- + TypeError + The command argument to invoke is missing. + """ + return await command(self, *args, **kwargs) + + async def _get_channel(self) -> abc.Messageable: + return self.channel + + @property + def source(self) -> Union[Message, Interaction]: + """Union[:class:`Message`, :class:`Interaction`]: Property to return a message or interaction + depending on the context. + """ + raise NotImplementedError() + + @property + def cog(self) -> Optional[Cog]: + """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. + None if it does not exist.""" + + if self.command is None: + return None + return self.command.cog + + @utils.cached_property + def guild(self) -> Optional[Guild]: + """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. + None if not available.""" + return self.source.guild + + @utils.cached_property + def channel(self) -> MessageableChannel: + """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. + Shorthand for :attr:`.Message.channel`. + """ + return self.source.channel + + @utils.cached_property + def author(self) -> Union[User, Member]: + """Union[:class:`.User`, :class:`.Member`]: + Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` + """ + return self.source.author + + @utils.cached_property + def me(self) -> Union[Member, ClientUser]: + """Union[:class:`.Member`, :class:`.ClientUser`]: + Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message + message contexts, or when :meth:`Intents.guilds` is absent. + """ + # bot.user will never be None at this point. + return self.guild.me if self.guild and self.guild.me else self.bot.user # type: ignore + + @property + def voice_client(self) -> Optional[VoiceProtocol]: + r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" + return self.guild.voice_client if self.guild else None + + + + + + +class Invokable(Generic[CogT, P, T]): def __init__(self, func: CallbackT, **kwargs): self.module: Any = None self.cog: Optional[Cog] @@ -565,8 +681,8 @@ async def invoke(self, ctx: ContextT) -> None: # terminate the invoked_subcommand chain. # since we're in a regular command (and not a group) then # the invoked subcommand is None. - # ctx.invoked_subcommand = None - # ctx.subcommand_passed = None + ctx.invoked_subcommand = None + ctx.subcommand_passed = None injected = hooked_wrapped_callback(self, ctx, self.callback) await injected(*ctx.args, **ctx.kwargs) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 781c94cb72..7dd31758dc 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -31,6 +31,7 @@ import discord.abc import discord.utils from discord.message import Message +from ...commands.mixins import BaseContext if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -62,8 +63,8 @@ P = TypeVar("P") -class Context(discord.abc.Messageable, Generic[BotT]): - r"""Represents the context in which a command is being invoked under. +class Context(BaseContext, Generic[BotT]): + """Represents the context in which a command is being invoked under. This class contains a lot of metadata to help you understand more about the invocation context. This class is not created manually and is instead @@ -135,12 +136,10 @@ def __init__( command_failed: bool = False, current_parameter: Optional[inspect.Parameter] = None, ): + super().__init__(bot=bot, command=command, args=args, kwargs=kwargs) + self.message: Message = message - self.bot: BotT = bot - self.args: List[Any] = args or [] - self.kwargs: Dict[str, Any] = kwargs or {} self.prefix: Optional[str] = prefix - self.command: Optional[Command] = command self.view: StringView = view self.invoked_with: Optional[str] = invoked_with self.invoked_parents: List[str] = invoked_parents or [] @@ -150,6 +149,10 @@ def __init__( self.current_parameter: Optional[inspect.Parameter] = current_parameter self._state: ConnectionState = self.message._state + @property + def source(self) -> Message: + return self.message + async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| @@ -250,9 +253,6 @@ def valid(self) -> bool: """:class:`bool`: Checks if the invocation context is valid to be invoked with.""" return self.prefix is not None and self.command is not None - async def _get_channel(self) -> discord.abc.Messageable: - return self.channel - @property def clean_prefix(self) -> str: """:class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``. @@ -270,50 +270,6 @@ def clean_prefix(self) -> str: pattern = re.compile(r"<@!?%s>" % user.id) return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix) - @property - def cog(self) -> Optional[Cog]: - """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. - None if it does not exist.""" - - if self.command is None: - return None - return self.command.cog - - @discord.utils.cached_property - def guild(self) -> Optional[Guild]: - """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. - None if not available.""" - return self.message.guild - - @discord.utils.cached_property - def channel(self) -> MessageableChannel: - """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. - Shorthand for :attr:`.Message.channel`. - """ - return self.message.channel - - @discord.utils.cached_property - def author(self) -> Union[User, Member]: - """Union[:class:`~discord.User`, :class:`.Member`]: - Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` - """ - return self.message.author - - @discord.utils.cached_property - def me(self) -> Union[Member, ClientUser]: - """Union[:class:`.Member`, :class:`.ClientUser`]: - Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message - message contexts, or when :meth:`Intents.guilds` is absent. - """ - # bot.user will never be None at this point. - return self.guild.me if self.guild is not None and self.guild.me is not None else self.bot.user # type: ignore - - @property - def voice_client(self) -> Optional[VoiceProtocol]: - r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" - g = self.guild - return g.voice_client if g else None - async def send_help(self, *args: Any) -> Any: """send_help(entity=) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index b6578c26fb..c94f6f72eb 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -56,7 +56,7 @@ slash_command, user_command, ) -from ...commands.invokable import Invokable +from ...commands.mixins import Invokable from ...errors import * from .cog import Cog from .context import Context @@ -77,6 +77,7 @@ from discord.message import Message + from ...commands.mixins import CogT from ._types import Check, Coro, CoroFunc, Error, Hook @@ -113,7 +114,6 @@ MISSING: Any = discord.utils.MISSING T = TypeVar("T") -CogT = TypeVar("CogT", bound="Cog") CommandT = TypeVar("CommandT", bound="Command") ContextT = TypeVar("ContextT", bound="Context") # CHT = TypeVar('CHT', bound='Check') From 4d67d6403511b189e01fc51ed9cd36867d1a4793 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 11:35:50 -0400 Subject: [PATCH 11/54] Finish integrating BaseContext like bridging all the attributes and subclassed properties etc. --- discord/bot.py | 4 +- discord/commands/context.py | 136 +++++--------------------------- discord/commands/mixins.py | 27 ++++++- discord/ext/commands/context.py | 42 +--------- 4 files changed, 47 insertions(+), 162 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index 8b23ea40a8..514944a257 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -736,7 +736,9 @@ async def process_application_commands(self, interaction: Interaction, auto_sync return self.dispatch("application_command_auto_complete", interaction, command) ctx = await self.get_application_context(interaction) - ctx.command = command + if not ctx.command: + ctx.command = command + await self.invoke_application_command(ctx) async def on_application_command_auto_complete(self, interaction: Interaction, command: ApplicationCommand) -> None: diff --git a/discord/commands/context.py b/discord/commands/context.py index ec2991a75c..0c63d98f5d 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -29,23 +29,18 @@ import discord.abc from discord.interactions import InteractionMessage, InteractionResponse, Interaction from discord.webhook.async_ import Webhook +from .mixins import BaseContext if TYPE_CHECKING: from typing_extensions import ParamSpec import discord from .. import Bot - from ..state import ConnectionState - from ..voice_client import VoiceProtocol from .core import ApplicationCommand, Option - from ..interactions import Interaction, InteractionResponse, InteractionChannel - from ..guild import Guild - from ..member import Member + from ..interactions import Interaction, InteractionResponse from ..message import Message - from ..user import User from ..permissions import Permissions - from ..client import ClientUser from discord.webhook.async_ import Webhook from ..cog import Cog @@ -66,7 +61,7 @@ __all__ = ("ApplicationContext", "AutocompleteContext") -class ApplicationContext(discord.abc.Messageable): +class ApplicationContext(BaseContext): """Represents a Discord application command interaction context. This class is not created manually and is instead passed to application @@ -83,85 +78,29 @@ class ApplicationContext(discord.abc.Messageable): command: :class:`.ApplicationCommand` The command that this context belongs to. """ + command: Optional[ApplicationCommand] + + def __init__( + self, + bot: Bot, + interaction: Interaction, + *, + command: Optional[ApplicationCommand] = None, + args: List[Any] = None, + kwargs: Dict[str, Any] = None + ): + super().__init__(bot=bot, command=command, args=args, kwargs=kwargs) - def __init__(self, bot: Bot, interaction: Interaction): - self.bot = bot self.interaction = interaction # below attributes will be set after initialization - self.command: ApplicationCommand = None # type: ignore self.focused: Option = None # type: ignore self.value: str = None # type: ignore self.options: dict = None # type: ignore - self._state: ConnectionState = self.interaction._state - - async def _get_channel(self) -> Optional[InteractionChannel]: - return self.interaction.channel - - async def invoke( - self, - command: ApplicationCommand[CogT, P, T], - /, - *args: P.args, - **kwargs: P.kwargs, - ) -> T: - r"""|coro| - - Calls a command with the arguments given. - This is useful if you want to just call the callback that a - :class:`.ApplicationCommand` holds internally. - - .. note:: - - This does not handle converters, checks, cooldowns, pre-invoke, - or after-invoke hooks in any matter. It calls the internal callback - directly as-if it was a regular function. - You must take care in passing the proper arguments when - using this function. - - Parameters - ----------- - command: :class:`.ApplicationCommand` - The command that is going to be called. - \*args - The arguments to use. - \*\*kwargs - The keyword arguments to use. - - Raises - ------- - TypeError - The command argument to invoke is missing. - """ - return await command(self, *args, **kwargs) - - @cached_property - def channel(self) -> Optional[InteractionChannel]: - """Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]: - Returns the channel associated with this context's command. Shorthand for :attr:`.Interaction.channel`.""" - return self.interaction.channel - - @cached_property - def channel_id(self) -> Optional[int]: - """:class:`int`: Returns the ID of the channel associated with this context's command. - Shorthand for :attr:`.Interaction.channel_id`. - """ - return self.interaction.channel_id - - @cached_property - def guild(self) -> Optional[Guild]: - """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. - Shorthand for :attr:`.Interaction.guild`. - """ - return self.interaction.guild - - @cached_property - def guild_id(self) -> Optional[int]: - """:class:`int`: Returns the ID of the guild associated with this context's command. - Shorthand for :attr:`.Interaction.guild_id`. - """ - return self.interaction.guild_id + @property + def source(self) -> Interaction: + return self.interaction @cached_property def locale(self) -> Optional[str]: @@ -181,14 +120,6 @@ def guild_locale(self) -> Optional[str]: def app_permissions(self) -> Permissions: return self.interaction.app_permissions - @cached_property - def me(self) -> Optional[Union[Member, ClientUser]]: - """Union[:class:`.Member`, :class:`.ClientUser`]: - Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message - message contexts, or when :meth:`Intents.guilds` is absent. - """ - return self.interaction.guild.me if self.interaction.guild is not None else self.bot.user - @cached_property def message(self) -> Optional[Message]: """Optional[:class:`.Message`]: Returns the message sent with this context's command. @@ -196,25 +127,6 @@ def message(self) -> Optional[Message]: """ return self.interaction.message - @cached_property - def user(self) -> Optional[Union[Member, User]]: - """Union[:class:`.Member`, :class:`.User`]: Returns the user that sent this context's command. - Shorthand for :attr:`.Interaction.user`. - """ - return self.interaction.user - - author: Optional[Union[Member, User]] = user - - @property - def voice_client(self) -> Optional[VoiceProtocol]: - """Optional[:class:`.VoiceProtocol`]: Returns the voice client associated with this context's command. - Shorthand for :attr:`Interaction.guild.voice_client<~discord.Guild.voice_client>`, if applicable. - """ - if self.interaction.guild is None: - return None - - return self.interaction.guild.voice_client - @cached_property def response(self) -> InteractionResponse: """:class:`.InteractionResponse`: Returns the response object associated with this context's command. @@ -338,19 +250,9 @@ async def delete(self, *, delay: Optional[float] = None) -> None: def edit(self) -> Callable[..., Awaitable[InteractionMessage]]: return self.interaction.edit_original_message - @property - def cog(self) -> Optional[Cog]: - """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. - ``None`` if it does not exist. - """ - if self.command is None: - return None - - return self.command.cog - class AutocompleteContext: - """Represents context for a slash command's option autocomplete. + """Represents context for a slash command's option autocomplete. This ***does not*** inherent from :class:`.BaseContext`. This class is not created manually and is instead passed to an :class:`.Option`'s autocomplete callback. diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index a7754b0976..d7f65db837 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -38,6 +38,7 @@ from ..interactions import Interaction from ..abc import MessageableChannel from ..voice_client import VoiceProtocol + from ..state import ConnectionState P = ParamSpec("P") @@ -181,6 +182,10 @@ def source(self) -> Union[Message, Interaction]: """ raise NotImplementedError() + @property + def _state(self) -> ConnectionState: + return self.source._state + @property def cog(self) -> Optional[Cog]: """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. @@ -196,6 +201,11 @@ def guild(self) -> Optional[Guild]: None if not available.""" return self.source.guild + @utils.cached_property + def channel_id(self) -> Optional[int]: + """:class:`int`: Returns the ID of the guild associated with this context's command.""" + return getattr(self.source, "guild_id", self.guild.id if self.guild else None) + @utils.cached_property def channel(self) -> MessageableChannel: """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. @@ -203,6 +213,11 @@ def channel(self) -> MessageableChannel: """ return self.source.channel + @utils.cached_property + def channel_id(self) -> Optional[int]: + """:class:`int`: Returns the ID of the channel associated with this context's command.""" + return getattr(self.source, "channel_id", self.channel.id if self.channel else None) + @utils.cached_property def author(self) -> Union[User, Member]: """Union[:class:`.User`, :class:`.Member`]: @@ -210,6 +225,11 @@ def author(self) -> Union[User, Member]: """ return self.source.author + @property + def user(self) -> Union[User, Member]: + """Union[:class:`.User`, :class:`.Member`]: Alias for :attr:`BaseContext.author`.""" + return self.author + @utils.cached_property def me(self) -> Union[Member, ClientUser]: """Union[:class:`.Member`, :class:`.ClientUser`]: @@ -653,7 +673,8 @@ async def call_after_hooks(self, ctx: ContextT) -> None: await hook(ctx) async def _parse_arguments(self, ctx: ContextT) -> None: - return + """Parses arguments and attaches them to the context class (Union[:class:`~ext.commands.Context`, :class:`.ApplicationContext`])""" + raise NotImplementedError() async def prepare(self, ctx: ContextT) -> None: ctx.command = self @@ -681,8 +702,8 @@ async def invoke(self, ctx: ContextT) -> None: # terminate the invoked_subcommand chain. # since we're in a regular command (and not a group) then # the invoked subcommand is None. - ctx.invoked_subcommand = None - ctx.subcommand_passed = None + # ctx.invoked_subcommand = None + # ctx.subcommand_passed = None injected = hooked_wrapped_callback(self, ctx, self.callback) await injected(*ctx.args, **ctx.kwargs) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 7dd31758dc..fe1cf0bd8a 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -36,13 +36,6 @@ if TYPE_CHECKING: from typing_extensions import ParamSpec - from discord.abc import MessageableChannel - from discord.guild import Guild - from discord.member import Member - from discord.state import ConnectionState - from discord.user import ClientUser, User - from discord.voice_client import VoiceProtocol - from .bot import AutoShardedBot, Bot from .cog import Cog from .core import Command @@ -118,6 +111,7 @@ class Context(BaseContext, Generic[BotT]): A boolean that indicates if the command failed to be parsed, checked, or invoked. """ + command: Optional[Command] def __init__( self, @@ -147,45 +141,11 @@ def __init__( self.subcommand_passed: Optional[str] = subcommand_passed self.command_failed: bool = command_failed self.current_parameter: Optional[inspect.Parameter] = current_parameter - self._state: ConnectionState = self.message._state @property def source(self) -> Message: return self.message - async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: - r"""|coro| - - Calls a command with the arguments given. - - This is useful if you want to just call the callback that a - :class:`.Command` holds internally. - - .. note:: - - This does not handle converters, checks, cooldowns, pre-invoke, - or after-invoke hooks in any matter. It calls the internal callback - directly as-if it was a regular function. - - You must take care in passing the proper arguments when - using this function. - - Parameters - ----------- - command: :class:`.Command` - The command that is going to be called. - \*args - The arguments to use. - \*\*kwargs - The keyword arguments to use. - - Raises - ------- - TypeError - The command argument to invoke is missing. - """ - return await command(self, *args, **kwargs) - async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: """|coro| From f9b69be1f4f7b9426a2042279ae97b70b1f57f7a Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 11:36:25 -0400 Subject: [PATCH 12/54] Remove test file --- ok.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 ok.py diff --git a/ok.py b/ok.py deleted file mode 100644 index c88045726d..0000000000 --- a/ok.py +++ /dev/null @@ -1,18 +0,0 @@ -class AA: - def __init__(self): - print("AA") - -class AB: - def __init__(self): - print("AB") - -class BA(AB, AA): - def __init__(self): - super(AB, self).__init__() - -test = BA() - -import time -time.sleep(2) - -print(BA) From cf1aee8c03ed709c5882e48db583467e55818226 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 11:56:32 -0400 Subject: [PATCH 13/54] Move parents and root_parent --- discord/commands/mixins.py | 28 ++++++++++++++++++++++++++++ discord/ext/commands/core.py | 30 ------------------------------ 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index d7f65db837..2bebd9a5d9 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -309,6 +309,34 @@ def callback(self, func: CallbackT) -> None: def cooldown(self): return self._buckets._cooldown + @property + def parents(self) -> List[Invokable]: + """List[:class:`Invokable`]: Retrieves the parents of this command. + + If the command has no parents then it returns an empty :class:`list`. + + For example in commands ``?a b c test``, the parents are ``[c, b, a]``. + """ + entries = [] + command = self + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command) + + return entries + + @property + def root_parent(self) -> Optional[Invokable]: + """Optional[:class:`Invokable`]: Retrieves the root parent of this command. + + If the command has no parents then it returns ``None``. + + For example in commands ``?a b c test``, the root parent is ``a``. + """ + if not self.parent: + return None + return self.parents[-1] + @property def full_parent_name(self) -> Optional[str]: """:class:`str`: Retrieves the fully qualified parent command name. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index c94f6f72eb..689e33bb7b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -522,36 +522,6 @@ def clean_params(self) -> Dict[str, inspect.Parameter]: return result - @property - def parents(self) -> List[Group]: - """List[:class:`Group`]: Retrieves the parents of this command. - - If the command has no parents then it returns an empty :class:`list`. - - For example in commands ``?a b c test``, the parents are ``[c, b, a]``. - - .. versionadded:: 1.1 - """ - entries = [] - command = self - while command.parent is not None: # type: ignore - command = command.parent # type: ignore - entries.append(command) - - return entries - - @property - def root_parent(self) -> Optional[Group]: - """Optional[:class:`Group`]: Retrieves the root parent of this command. - - If the command has no parents then it returns ``None``. - - For example in commands ``?a b c test``, the root parent is ``a``. - """ - if not self.parent: - return None - return self.parents[-1] - async def _parse_arguments(self, ctx: Context) -> None: ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = {} From 63e230edf6f676102a5f08347f0121d73775b43e Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 12:26:16 -0400 Subject: [PATCH 14/54] Implement reinvoke (ctx & cmd) and extra ctx attrs --- discord/commands/context.py | 49 +++++++++++++++++++++++++++-- discord/commands/core.py | 1 + discord/commands/mixins.py | 55 ++++++++++++++++++++++++++------- discord/ext/commands/context.py | 44 +++----------------------- discord/ext/commands/core.py | 18 ----------- 5 files changed, 96 insertions(+), 71 deletions(-) diff --git a/discord/commands/context.py b/discord/commands/context.py index 0c63d98f5d..7787286e45 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -87,9 +87,10 @@ def __init__( *, command: Optional[ApplicationCommand] = None, args: List[Any] = None, - kwargs: Dict[str, Any] = None + kwargs: Dict[str, Any] = None, + **kwargs2 ): - super().__init__(bot=bot, command=command, args=args, kwargs=kwargs) + super().__init__(bot=bot, command=command, args=args, kwargs=kwargs, **kwargs2) self.interaction = interaction @@ -98,6 +99,50 @@ def __init__( self.value: str = None # type: ignore self.options: dict = None # type: ignore + async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: + """|coro| + + Calls the command again. + + This is similar to :meth:`~.BaseContext.invoke` except that it bypasses + checks, cooldowns, and error handlers. + + .. note:: + + If you want to bypass :exc:`.UserInputError` derived exceptions, + it is recommended to use the regular :meth:`~.Context.invoke` + as it will work more naturally. After all, this will end up + using the old arguments the user has used and will thus just + fail again. + + Parameters + ------------ + call_hooks: :class:`bool` + Whether to call the before and after invoke hooks. + restart: :class:`bool` + Whether to start the call chain from the very beginning + or where we left off (i.e. the command that caused the error). + The default is to start where we left off. + + Raises + ------- + ValueError + The context to reinvoke is not valid. + """ + cmd = self.command + if cmd is None: + raise ValueError("This context is not valid.") + + if restart: + to_call = cmd.root_parent or cmd + else: + to_call = cmd + + try: + await to_call.reinvoke(self, call_hooks=call_hooks) + finally: + self.command = cmd + @property def source(self) -> Interaction: return self.interaction diff --git a/discord/commands/core.py b/discord/commands/core.py index 657b312c6c..c07a9624e6 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -773,6 +773,7 @@ async def invoke(self, ctx: ApplicationContext) -> None: command = find(lambda x: x.name == option["name"], self.subcommands) option["resolved"] = resolved ctx.interaction.data = option + ctx.invoked_subcommand = command await command.invoke(ctx) async def invoke_autocomplete_callback(self, ctx: AutocompleteContext) -> None: diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 2bebd9a5d9..ec5f2628be 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -128,17 +128,37 @@ class _BaseCommand: class BaseContext(abc.Messageable, Generic[BotT]): def __init__( self, - *, bot: Bot, command: Optional[Invokable], args: List[Any] = utils.MISSING, - kwargs: Dict[str, Any] = utils.MISSING + kwargs: Dict[str, Any] = utils.MISSING, + *, + invoked_with: Optional[str] = None, + invoked_parents: List[str] = utils.MISSING, + invoked_subcommand: Optional[Invokable] = None, + subcommand_passed: Optional[str] = None, + command_failed: bool = False + ): self.bot: Bot = bot self.command: Optional[Invokable] = command self.args: List[Any] = args or [] self.kwargs: Dict[str, Any] = kwargs or {} + self.invoked_with: Optional[str] = invoked_with + if not self.invoked_with and command: + self.invoked_with = command.name + + self.invoked_parents: List[str] = invoked_parents or [] + if not self.invoked_parents and command: + self.invoked_parents = [i.name for i in command.parents] + + # This will always be None for slash commands + self.subcommand_passed: Optional[str] = subcommand_passed + + self.invoked_subcommand: Optional[Invokable] = invoked_subcommand + self.command_failed: bool = command_failed + async def invoke(self, command: Invokable[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| @@ -202,7 +222,7 @@ def guild(self) -> Optional[Guild]: return self.source.guild @utils.cached_property - def channel_id(self) -> Optional[int]: + def guild_id(self) -> Optional[int]: """:class:`int`: Returns the ID of the guild associated with this context's command.""" return getattr(self.source, "guild_id", self.guild.id if self.guild else None) @@ -241,14 +261,10 @@ def me(self) -> Union[Member, ClientUser]: @property def voice_client(self) -> Optional[VoiceProtocol]: - r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" + """Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" return self.guild.voice_client if self.guild else None - - - - class Invokable(Generic[CogT, P, T]): def __init__(self, func: CallbackT, **kwargs): self.module: Any = None @@ -730,11 +746,28 @@ async def invoke(self, ctx: ContextT) -> None: # terminate the invoked_subcommand chain. # since we're in a regular command (and not a group) then # the invoked subcommand is None. - # ctx.invoked_subcommand = None - # ctx.subcommand_passed = None + ctx.invoked_subcommand = None + ctx.subcommand_passed = None injected = hooked_wrapped_callback(self, ctx, self.callback) await injected(*ctx.args, **ctx.kwargs) + async def reinvoke(self, ctx: ContextT, *, call_hooks: bool = False) -> None: + ctx.command = self + await self._parse_arguments(ctx) + + if call_hooks: + await self.call_before_hooks(ctx) + + ctx.invoked_subcommand = None + try: + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + await self.call_after_hooks(ctx) + def copy(self): """Creates a copy of this command. @@ -746,7 +779,7 @@ def copy(self): ret = self.__class__(self.callback, **self.__original_kwargs__) return self._ensure_assignment_on_copy(ret) - def _ensure_assignment_on_copy(self, other): + def _ensure_assignment_on_copy(self, other: Invokable): other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke if self.checks != other.checks: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index fe1cf0bd8a..5d0914bdb4 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -32,6 +32,7 @@ import discord.utils from discord.message import Message from ...commands.mixins import BaseContext +from ...commands.context import ApplicationContext if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -123,59 +124,22 @@ def __init__( kwargs: Dict[str, Any] = MISSING, prefix: Optional[str] = None, command: Optional[Command] = None, - invoked_with: Optional[str] = None, - invoked_parents: List[str] = MISSING, - invoked_subcommand: Optional[Command] = None, - subcommand_passed: Optional[str] = None, - command_failed: bool = False, current_parameter: Optional[inspect.Parameter] = None, + **kwargs2, ): - super().__init__(bot=bot, command=command, args=args, kwargs=kwargs) + super().__init__(bot=bot, command=command, args=args, kwargs=kwargs, **kwargs2) self.message: Message = message self.prefix: Optional[str] = prefix self.view: StringView = view - self.invoked_with: Optional[str] = invoked_with - self.invoked_parents: List[str] = invoked_parents or [] - self.invoked_subcommand: Optional[Command] = invoked_subcommand - self.subcommand_passed: Optional[str] = subcommand_passed - self.command_failed: bool = command_failed self.current_parameter: Optional[inspect.Parameter] = current_parameter @property def source(self) -> Message: return self.message + @discord.utils.copy_doc(ApplicationContext.reinvoke) async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: - """|coro| - - Calls the command again. - - This is similar to :meth:`~.Context.invoke` except that it bypasses - checks, cooldowns, and error handlers. - - .. note:: - - If you want to bypass :exc:`.UserInputError` derived exceptions, - it is recommended to use the regular :meth:`~.Context.invoke` - as it will work more naturally. After all, this will end up - using the old arguments the user has used and will thus just - fail again. - - Parameters - ------------ - call_hooks: :class:`bool` - Whether to call the before and after invoke hooks. - restart: :class:`bool` - Whether to start the call chain from the very beginning - or where we left off (i.e. the command that caused the error). - The default is to start where we left off. - - Raises - ------- - ValueError - The context to reinvoke is not valid. - """ cmd = self.command view = self.view if cmd is None: diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 689e33bb7b..0403b4f1f2 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -572,24 +572,6 @@ async def _parse_arguments(self, ctx: Context) -> None: if not self.ignore_extra and not view.eof: raise TooManyArguments(f"Too many arguments passed to {self.qualified_name}") - # TODO: - async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: - ctx.command = self - await self._parse_arguments(ctx) - - if call_hooks: - await self.call_before_hooks(ctx) - - ctx.invoked_subcommand = None - try: - await self.callback(*ctx.args, **ctx.kwargs) # type: ignore - except: - ctx.command_failed = True - raise - finally: - if call_hooks: - await self.call_after_hooks(ctx) - @property def short_doc(self) -> str: """:class:`str`: Gets the "short" documentation of a command. From 36085e34919d994bf25321b717bcc0a5b1c5fbc1 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 12:32:56 -0400 Subject: [PATCH 15/54] Some Runtime fixes --- discord/commands/mixins.py | 51 ++++++++++++++++++------------------ discord/ext/commands/core.py | 46 ++++++++++++++++---------------- 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index ec5f2628be..9ccf799da1 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -30,7 +30,6 @@ from typing_extensions import ParamSpec from ..bot import Bot, AutoShardedBot - from ..cog import Cog from ..user import User, ClientUser from ..member import Member from ..guild import Guild @@ -41,33 +40,34 @@ from ..state import ConnectionState P = ParamSpec("P") +else: + P = TypeVar("P") - BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") - CogT = TypeVar("CogT", bound="Cog") - CallbackT = TypeVar("CallbackT") - ContextT = TypeVar("ContextT", bound="BaseContext") +BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") +CogT = TypeVar("CogT", bound="Cog") +CallbackT = TypeVar("CallbackT") +ContextT = TypeVar("ContextT", bound="BaseContext") - T = TypeVar("T") - Coro = Coroutine[Any, Any, T] - MaybeCoro = Union[T, Coro[T]] +T = TypeVar("T") +Coro = Coroutine[Any, Any, T] +MaybeCoro = Union[T, Coro[T]] - Check = Union[ - Callable[[Cog, ContextT], MaybeCoro[bool]], # TODO: replace with stardized context superclass - Callable[[ContextT], MaybeCoro[bool]], # as well as for the others - ] +Check = Union[ + Callable[[CogT, ContextT], MaybeCoro[bool]], + Callable[[ContextT], MaybeCoro[bool]], +] - Error = Union[ - Callable[[Cog, "BaseContext[Any]", CommandError], Coro[Any]], - Callable[["BaseContext[Any]", CommandError], Coro[Any]], - ] - ErrorT = TypeVar("ErrorT", bound="Error") - - Hook = Union[ - Callable[[Cog, ContextT], Coro[Any]], - Callable[[ContextT], Coro[Any]] - ] - HookT = TypeVar("HookT", bound="Hook") +Error = Union[ + Callable[[CogT, "BaseContext[Any]", CommandError], Coro[Any]], + Callable[["BaseContext[Any]", CommandError], Coro[Any]], +] +ErrorT = TypeVar("ErrorT", bound="Error") +Hook = Union[ + Callable[[CogT, ContextT], Coro[Any]], + Callable[[ContextT], Coro[Any]] +] +HookT = TypeVar("HookT", bound="Hook") def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: @@ -269,7 +269,7 @@ class Invokable(Generic[CogT, P, T]): def __init__(self, func: CallbackT, **kwargs): self.module: Any = None self.cog: Optional[Cog] - self.parent: Optional[Invokable] = (parent := kwargs.get("parent")) if isinstance(parent, _BaseCommand) else None + self.parent: Optional[Invokable] = parent if isinstance((parent := kwargs.get("parent")), _BaseCommand) else None self.callback: CallbackT = func self.name: str = str(kwargs.get("name", func.__name__)) @@ -708,7 +708,8 @@ async def call_after_hooks(self, ctx: ContextT) -> None: # call the cog local hook if applicable: if cog is not None: - hook = Cog._get_overridden_method(cog.cog_after_invoke) + # :troll: + hook = cog.__class__._get_overridden_method(cog.cog_after_invoke) if hook is not None: await hook(ctx) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 0403b4f1f2..148aabb031 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -56,7 +56,7 @@ slash_command, user_command, ) -from ...commands.mixins import Invokable +from ...commands.mixins import Invokable, CogT from ...errors import * from .cog import Cog from .context import Context @@ -77,10 +77,30 @@ from discord.message import Message - from ...commands.mixins import CogT from ._types import Check, Coro, CoroFunc, Error, Hook + P = ParamSpec("P") + + CommandT = TypeVar("CommandT", bound="Command") + ContextT = TypeVar("ContextT", bound="Context") + # CHT = TypeVar('CHT', bound='Check') + GroupT = TypeVar("GroupT", bound="Group") + HookT = TypeVar("HookT", bound="Hook") + ErrorT = TypeVar("ErrorT", bound="Error") + + CallbackT = Union[ + Callable[ + [Concatenate[CogT, ContextT, P]], + Coro[T] + ], + Callable[ + [Concatenate[ContextT, P]], + Coro[T] + ], + ] +else: + P = TypeVar("P") __all__ = ( "Command", "Group", @@ -114,28 +134,6 @@ MISSING: Any = discord.utils.MISSING T = TypeVar("T") -CommandT = TypeVar("CommandT", bound="Command") -ContextT = TypeVar("ContextT", bound="Context") -# CHT = TypeVar('CHT', bound='Check') -GroupT = TypeVar("GroupT", bound="Group") -HookT = TypeVar("HookT", bound="Hook") -ErrorT = TypeVar("ErrorT", bound="Error") - -if TYPE_CHECKING: - P = ParamSpec("P") -else: - P = TypeVar("P") - -CallbackT = Union[ - Callable[ - [Concatenate[CogT, ContextT, P]], - Coro[T] - ], - Callable[ - [Concatenate[ContextT, P]], - Coro[T] - ], -] def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: From 07247a3d857d051f3cdc2e9bab9c2d92adcc3d5d Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 12:55:01 -0400 Subject: [PATCH 16/54] Last minute fixes --- discord/commands/core.py | 21 ++++--- discord/commands/mixins.py | 59 ++++++++++++-------- discord/ext/commands/core.py | 103 +++++------------------------------ 3 files changed, 60 insertions(+), 123 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index c07a9624e6..0988ac65de 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -50,6 +50,7 @@ from ..channel import _threaded_guild_channel_factory from ..enums import MessageType, SlashCommandOptionType, try_enum, Enum as DiscordEnum from ..errors import ( + CommandError, ApplicationCommandError, ApplicationCommandInvokeError, ClientException, @@ -86,20 +87,19 @@ from .. import Permissions from ..cog import Cog + from .mixins import BaseContext -T = TypeVar("T") -CogT = TypeVar("CogT", bound="Cog") -Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) - -if TYPE_CHECKING: P = ParamSpec("P") else: P = TypeVar("P") -def wrap_callback(coro): - from ..ext.commands.errors import CommandError +T = TypeVar("T") +CogT = TypeVar("CogT", bound="Cog") +Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) + +def wrap_callback(coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): try: @@ -195,8 +195,8 @@ def __eq__(self, other) -> bool: def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) - def _set_cog(self, cog): - self.cog = cog + async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + ctx.bot.dispatch("application_command_error", ctx, error) class SlashCommand(ApplicationCommand): @@ -716,8 +716,7 @@ def create_subgroup( """ if self.parent is not None: - # TODO: Improve this error message - raise Exception("a subgroup cannot have a subgroup") + raise Exception("A command subgroup can only have commands and not any more groups.") sub_command_group = SlashCommandGroup(name, description, guild_ids, parent=self, **kwargs) self.subcommands.append(sub_command_group) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 9ccf799da1..305ef6cf7e 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -23,6 +23,7 @@ CommandError, CommandInvokeError, DisabledCommand, + CommandOnCooldown, ) from .cooldowns import BucketType, CooldownMapping, MaxConcurrency @@ -38,6 +39,7 @@ from ..abc import MessageableChannel from ..voice_client import VoiceProtocol from ..state import ConnectionState + from ..cog import Cog P = ParamSpec("P") else: @@ -267,13 +269,14 @@ def voice_client(self) -> Optional[VoiceProtocol]: class Invokable(Generic[CogT, P, T]): def __init__(self, func: CallbackT, **kwargs): - self.module: Any = None - self.cog: Optional[Cog] - self.parent: Optional[Invokable] = parent if isinstance((parent := kwargs.get("parent")), _BaseCommand) else None self.callback: CallbackT = func + self.parent: Optional[Invokable] = parent if isinstance((parent := kwargs.get("parent")), _BaseCommand) else None + self.cog: Optional[CogT] + self.module: Any = None self.name: str = str(kwargs.get("name", func.__name__)) self.enabled: bool = kwargs.get("enabled", True) + self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) # checks if checks := getattr(func, "__commands_checks__", []): @@ -546,27 +549,30 @@ async def can_run(self, ctx: ContextT) -> bool: finally: ctx.command = original - # depends on what to do with the application_command_error event + async def _dispatch_error(self, ctx: ContextT, error: Exception) -> None: + # since I don't want to copy paste code, subclassed Contexts + # dispatch it to their corresponding events + raise NotImplementedError() - # async def dispatch_error(self, ctx: ContextT, error: Exception) -> None: - # ctx.command_failed = True - # cog = self.cog + async def dispatch_error(self, ctx: ContextT, error: Exception) -> None: + ctx.command_failed = True + cog = self.cog - # if coro := getattr(self, "on_error", None): - # injected = wrap_callback(coro) - # if cog is not None: - # await injected(cog, ctx, error) - # else: - # await injected(ctx, error) + if coro := getattr(self, "on_error", None): + injected = wrap_callback(coro) + if cog is not None: + await injected(cog, ctx, error) + else: + await injected(ctx, error) - # try: - # if cog is not None: - # local = cog.__class__._get_overridden_method(cog.cog_command_error) - # if local is not None: - # wrapped = wrap_callback(local) - # await wrapped(ctx, error) - # finally: - # ctx.bot.dispatch("application_command_error", ctx, error) + try: + if cog is not None: + local = cog.__class__._get_overridden_method(cog.cog_command_error) + if local is not None: + wrapped = wrap_callback(local) + await wrapped(ctx, error) + finally: + await self._dispatch_error(ctx, error) def add_check(self, func: Check) -> None: """Adds a check to the command. @@ -732,8 +738,12 @@ async def prepare(self, ctx: ContextT) -> None: await self._max_concurrency.acquire(ctx) # type: ignore try: - self._prepare_cooldowns(ctx) - await self._parse_arguments(ctx) + if self.cooldown_after_parsing: + await self._parse_arguments(ctx) + self._prepare_cooldowns(ctx) + else: + self._prepare_cooldowns(ctx) + await self._parse_arguments(ctx) await self.call_before_hooks(ctx) except: @@ -805,3 +815,6 @@ def _update_copy(self, kwargs: Dict[str, Any]): return self._ensure_assignment_on_copy(copy) else: return self.copy() + + def _set_cog(self, cog: CogT): + self.cog = cog diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 148aabb031..789aa8c6d0 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -24,11 +24,9 @@ """ from __future__ import annotations -import asyncio import functools import inspect import types - from typing import ( TYPE_CHECKING, Any, @@ -56,7 +54,14 @@ slash_command, user_command, ) -from ...commands.mixins import Invokable, CogT +from ...commands.mixins import ( + CogT, + Invokable, + hooked_wrapped_callback, + unwrap_function, + wrap_callback, +) +from ...enums import ChannelType from ...errors import * from .cog import Cog from .context import Context @@ -68,10 +73,10 @@ DynamicCooldownMapping, MaxConcurrency, ) -from ...enums import ChannelType - from .errors import * +T = TypeVar("T") + if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard @@ -98,9 +103,10 @@ Coro[T] ], ] - else: P = TypeVar("P") + + __all__ = ( "Command", "Group", @@ -133,19 +139,6 @@ MISSING: Any = discord.utils.MISSING -T = TypeVar("T") - - -def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: - partial = functools.partial - while True: - if hasattr(function, "__wrapped__"): - function = function.__wrapped__ - elif isinstance(function, partial): - function = function.func - else: - return function - def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]: signature = inspect.signature(function) @@ -170,46 +163,6 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A return params -def wrap_callback(coro): - @functools.wraps(coro) - async def wrapped(*args, **kwargs): - try: - ret = await coro(*args, **kwargs) - except CommandError: - raise - except asyncio.CancelledError: - return - except Exception as exc: - raise CommandInvokeError(exc) from exc - return ret - - return wrapped - - -def hooked_wrapped_callback(command, ctx, coro): - @functools.wraps(coro) - async def wrapped(*args, **kwargs): - try: - ret = await coro(*args, **kwargs) - except CommandError: - ctx.command_failed = True - raise - except asyncio.CancelledError: - ctx.command_failed = True - return - except Exception as exc: - ctx.command_failed = True - raise CommandInvokeError(exc) from exc - finally: - if command._max_concurrency is not None: - await command._max_concurrency.release(ctx) - - await command.call_after_hooks(ctx) - return ret - - return wrapped - - class _CaseInsensitiveDict(dict): def __contains__(self, k): return super().__contains__(k.casefold()) @@ -361,14 +314,6 @@ def __init__( self.require_var_positional: bool = kwargs.get("require_var_positional", False) self.ignore_extra: bool = kwargs.get("ignore_extra", True) - # TODO: maybe??? - # self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) - # TODO: typing - # self.cog: Optional[CogT] = None - - # bandaid for the fact that sometimes parent can be the bot instance - parent = kwargs.get("parent") - self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore @property def callback( @@ -392,28 +337,8 @@ def callback( self.params = get_signature_parameters(func, globalns) - async def dispatch_error(self, ctx: Context, error: Exception) -> None: - ctx.command_failed = True - cog = self.cog - try: - coro = self.on_error - except AttributeError: - pass - else: - injected = wrap_callback(coro) - if cog is not None: - await injected(cog, ctx, error) - else: - await injected(ctx, error) - - try: - if cog is not None: - local = Cog._get_overridden_method(cog.cog_command_error) - if local is not None: - wrapped = wrap_callback(local) - await wrapped(ctx, error) - finally: - ctx.bot.dispatch("command_error", ctx, error) + async def _dispatch_error(self, ctx: Context, error: Exception) -> None: + ctx.bot.dispatch("command_error", ctx, error) async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: required = param.default is param.empty From b4b6fdd55567c3891cd46fb0c9e90fe00f697748 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 15:33:24 -0400 Subject: [PATCH 17/54] Make docs great again --- discord/commands/__init__.py | 2 +- discord/commands/core.py | 192 +++++----------------- discord/commands/mixins.py | 277 +++++++++++++++++++++----------- discord/ext/commands/context.py | 34 +--- discord/ext/commands/core.py | 41 +---- docs/api.rst | 92 +++++++---- 6 files changed, 285 insertions(+), 353 deletions(-) diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py index cc4a8d25bf..7a91ea8f15 100644 --- a/discord/commands/__init__.py +++ b/discord/commands/__init__.py @@ -27,4 +27,4 @@ from .core import * from .options import * from .permissions import * -from .mixins import Invokable, _BaseCommand +from .mixins import Invokable, _BaseCommand, BaseContext diff --git a/discord/commands/core.py b/discord/commands/core.py index 0988ac65de..a813b6cbb4 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -26,7 +26,6 @@ from __future__ import annotations import asyncio -import functools import inspect import re import types @@ -49,13 +48,7 @@ from ..channel import _threaded_guild_channel_factory from ..enums import MessageType, SlashCommandOptionType, try_enum, Enum as DiscordEnum -from ..errors import ( - CommandError, - ApplicationCommandError, - ApplicationCommandInvokeError, - ClientException, - ValidationError, -) +from ..errors import ClientException, ValidationError from ..member import Member from ..message import Attachment, Message from ..object import Object @@ -63,7 +56,7 @@ from ..threads import Thread from ..user import User from ..utils import find, MISSING -from .mixins import Invokable, _BaseCommand +from .mixins import Invokable, _BaseCommand, CogT from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice @@ -99,59 +92,6 @@ Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) -def wrap_callback(coro): - @functools.wraps(coro) - async def wrapped(*args, **kwargs): - try: - ret = await coro(*args, **kwargs) - except ApplicationCommandError: - raise - except CommandError: - raise - except asyncio.CancelledError: - return - except Exception as exc: - raise ApplicationCommandInvokeError(exc) from exc - return ret - - return wrapped - - -def hooked_wrapped_callback(command, ctx, coro): - from ..ext.commands.errors import CommandError - - @functools.wraps(coro) - async def wrapped(arg): - try: - ret = await coro(arg) - except ApplicationCommandError: - raise - except CommandError: - raise - except asyncio.CancelledError: - return - except Exception as exc: - raise ApplicationCommandInvokeError(exc) from exc - finally: - if hasattr(command, "_max_concurrency") and command._max_concurrency is not None: - await command._max_concurrency.release(ctx) - await command.call_after_hooks(ctx) - return ret - - return wrapped - - -def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: - partial = functools.partial - while True: - if hasattr(function, "__wrapped__"): - function = function.__wrapped__ - elif isinstance(function, partial): - function = function.func - else: - return function - - def _validate_names(obj): validate_chat_input_name(obj.name) if obj.name_localizations: @@ -167,6 +107,16 @@ def _validate_descriptions(obj): class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): + """Base class for all Application Commands, including: + + - :class:`.SlashCommand` + - :class:`.SlashCommandGroup` + - :class:`ContextMenuCommand` which in turn is a superclass of + - :class:`MessageCommand` and + - :class:`UserCommand` + + This is a subclass of :class:`.Invokable`. + """ __original_kwargs__: Dict[str, Any] cog = None @@ -174,7 +124,6 @@ def __init__(self, func: Callable, **kwargs) -> None: super().__init__(func, **kwargs) self.id: Optional[int] = kwargs.get("id") self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None) - self.parent = kwargs.get("parent") # Permissions self.default_member_permissions: Optional["Permissions"] = getattr( @@ -200,46 +149,27 @@ async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: class SlashCommand(ApplicationCommand): - r"""A class that implements the protocol for a slash command. + """A class that implements the protocol for a slash command. These are not created manually, instead they are created via the decorator or functional interface. + This is a subclass of :class:`.Invokable`. + .. versionadded:: 2.0 Attributes ----------- - name: :class:`str` - The name of the command. - callback: :ref:`coroutine ` - The coroutine that is executed when the command is called. - description: Optional[:class:`str`] - The description for the command. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. options: List[:class:`Option`] The parameters for this command. parent: Optional[:class:`SlashCommandGroup`] - The parent group that this command belongs to. ``None`` if there - isn't one. - mention: :class:`str` - Returns a string that allows you to mention the slash command. + The parent group that this command belongs to. guild_only: :class:`bool` Whether the command should only be usable inside a guild. default_member_permissions: :class:`~discord.Permissions` The default permissions a member needs to be able to run the command. - cog: Optional[:class:`Cog`] - The cog that this command belongs to. ``None`` if there isn't one. - checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] - A list of predicates that verifies if the command could be executed - with the given :class:`.ApplicationContext` as the sole parameter. If an exception - is necessary to be thrown to signal failure, then one inherited from - :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then - :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` - event. - cooldown: Optional[:class:`~discord.ext.commands.Cooldown`] - The cooldown applied when the command is invoked. ``None`` if the command - doesn't have a cooldown. name_localizations: Optional[Dict[:class:`str`, :class:`str`]] The name localizations for this command. The values of this should be ``"locale": "name"``. See `here `_ for a list of valid locales. @@ -395,6 +325,7 @@ def is_subcommand(self) -> bool: @property def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the slash command.""" return f"" def to_dict(self) -> Dict: @@ -419,6 +350,8 @@ def to_dict(self) -> Dict: return as_dict async def _parse_arguments(self, ctx: ApplicationContext) -> None: + ctx.args = [ctx] if self.cog is None else [self.cog, ctx] + # TODO: Parse the args better kwargs = {} for arg in ctx.interaction.data.get("options", []): @@ -513,7 +446,6 @@ async def _parse_arguments(self, ctx: ApplicationContext) -> None: if o._parameter_name not in kwargs: kwargs[o._parameter_name] = o.default - ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = kwargs async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): @@ -542,15 +474,17 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): class SlashCommandGroup(ApplicationCommand): - r"""A class that implements the protocol for a slash command group. + """A class that implements the protocol for a slash command group. These can be created manually, but they should be created via the decorator or functional interface. + This is a subclass of :class:`.Invokable`. + + .. versionadded:: 2.0 + Attributes ----------- - name: :class:`str` - The name of the command. description: Optional[:class:`str`] The description for the command. guild_ids: Optional[List[:class:`int`]] @@ -562,13 +496,6 @@ class SlashCommandGroup(ApplicationCommand): Whether the command should only be usable inside a guild. default_member_permissions: :class:`~discord.Permissions` The default permissions a member needs to be able to run the command. - checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] - A list of predicates that verifies if the command could be executed - with the given :class:`.ApplicationContext` as the sole parameter. If an exception - is necessary to be thrown to signal failure, then one inherited from - :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then - :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` - event. name_localizations: Optional[Dict[:class:`str`, :class:`str`]] The name localizations for this command. The values of this should be ``"locale": "name"``. See `here `_ for a list of valid locales. @@ -801,37 +728,22 @@ def _set_cog(self, cog): class ContextMenuCommand(ApplicationCommand): - r"""A class that implements the protocol for context menu commands. + """A base class that implements the protocol for context menu commands. - These are not created manually, instead they are created via the - decorator or functional interface. + These are not meant to be directly used, same as :class:`ApplicationCommand`. + + This is a subclass of :class:`.Invokable` but does not support the ``parent`` attribute. .. versionadded:: 2.0 Attributes ----------- - name: :class:`str` - The name of the command. - callback: :ref:`coroutine ` - The coroutine that is executed when the command is called. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. guild_only: :class:`bool` Whether the command should only be usable inside a guild. default_member_permissions: :class:`~discord.Permissions` The default permissions a member needs to be able to run the command. - cog: Optional[:class:`Cog`] - The cog that this command belongs to. ``None`` if there isn't one. - checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] - A list of predicates that verifies if the command could be executed - with the given :class:`.ApplicationContext` as the sole parameter. If an exception - is necessary to be thrown to signal failure, then one inherited from - :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then - :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` - event. - cooldown: Optional[:class:`~discord.ext.commands.Cooldown`] - The cooldown applied when the command is invoked. ``None`` if the command - doesn't have a cooldown. name_localizations: Optional[Dict[:class:`str`, :class:`str`]] The name localizations for this command. The values of this should be ``"locale": "name"``. See `here `_ for a list of valid locales. @@ -907,28 +819,19 @@ def to_dict(self) -> Dict[str, Union[str, int]]: class UserCommand(ContextMenuCommand): - r"""A class that implements the protocol for user context menu commands. + """A class that implements the protocol for user context menu commands. These are not created manually, instead they are created via the decorator or functional interface. + This is a subclass of :class:`.Invokable` but does not support the ``parent`` attribute. + + .. versionadded:: 2.0 + Attributes ----------- - name: :class:`str` - The name of the command. - callback: :ref:`coroutine ` - The coroutine that is executed when the command is called. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. - cog: Optional[:class:`.Cog`] - The cog that this command belongs to. ``None`` if there isn't one. - checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] - A list of predicates that verifies if the command could be executed - with the given :class:`.ApplicationContext` as the sole parameter. If an exception - is necessary to be thrown to signal failure, then one inherited from - :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then - :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` - event. """ type = 2 @@ -968,28 +871,19 @@ async def _invoke(self, ctx: ApplicationContext) -> None: class MessageCommand(ContextMenuCommand): - r"""A class that implements the protocol for message context menu commands. + """A class that implements the protocol for message context menu commands. These are not created manually, instead they are created via the decorator or functional interface. + This is a subclass of :class:`.Invokable` but does not support the ``parent`` attribute. + + .. versionadded:: 2.0 + Attributes ----------- - name: :class:`str` - The name of the command. - callback: :ref:`coroutine ` - The coroutine that is executed when the command is called. guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. - cog: Optional[:class:`.Cog`] - The cog that this command belongs to. ``None`` if there isn't one. - checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] - A list of predicates that verifies if the command could be executed - with the given :class:`.ApplicationContext` as the sole parameter. If an exception - is necessary to be thrown to signal failure, then one inherited from - :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then - :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` - event. """ type = 3 @@ -1068,13 +962,11 @@ def message_command(**kwargs): def application_command(cls=SlashCommand, **attrs): """A decorator that transforms a function into an :class:`.ApplicationCommand`. More specifically, - usually one of :class:`.SlashCommand`, :class:`.UserCommand`, or :class:`.MessageCommand`. The exact class + one of :class:`.SlashCommand`, :class:`.UserCommand`, or :class:`.MessageCommand`. The exact class depends on the ``cls`` parameter. - By default, the ``description`` attribute is received automatically from the - docstring of the function and is cleaned up with the use of - ``inspect.cleandoc``. If the docstring is ``bytes``, then it is decoded - into :class:`str` using utf-8 encoding. - The ``name`` attribute also defaults to the function name unchanged. + + The ``description`` and ``name`` of the command are automatically inferred from the function name + and function docstring. .. versionadded:: 2.0 diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 305ef6cf7e..ee3a96439a 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -128,6 +128,53 @@ class _BaseCommand: class BaseContext(abc.Messageable, Generic[BotT]): + """A baseclass to provide ***basic & common functionality*** between + :class:`.ApplicationContext` and :class:`~ext.commands.Context`. + + This is a subclass of :class:`~abc.Messageable` and can be used to + send messages, etc. + + .. versionadded:: 2.2 + + Attributes + ---------- + bot: :class:`.Bot` + The bot that contains the command being executed. + command: Optional[:class:`Invokable`] + The command that is being invoked currently. + args: :class:`list` + The list of transformed arguments that were passed into the command. + If this is accessed during the :func:`.on_command_error` event + then this list could be incomplete. + kwargs: :class:`dict` + A dictionary of transformed arguments that were passed into the command. + Similar to :attr:`args`\, if this is accessed in the + :func:`.on_command_error` event then this dict could be incomplete. + invoked_with: Optional[:class:`str`] + The command name that triggered this invocation. Useful for finding out + which alias called the command. + invoked_parents: List[:class:`str`] + The command names of the parents that triggered this invocation. Useful for + finding out which aliases called the command. + + For example in commands ``?a b c test``, the invoked parents are ``['a', 'b', 'c']``. + invoked_subcommand: Optional[:class:`Invokable`] + The subcommand that was invoked. + If no valid subcommand was invoked then this is equal to ``None``. + subcommand_passed: Optional[:class:`str`] + The string that was attempted to call a subcommand. This does not have + to point to a valid registered subcommand and could just point to a + nonsense string. If nothing was passed to attempt a call to a + subcommand then this is set to ``None``. + + .. note:: + + This will always be ``None`` if accessed on through a slash command. + + command_failed: :class:`bool` + A boolean that indicates if the command failed to be parsed, checked, + or invoked. + """ def __init__( self, bot: Bot, @@ -162,7 +209,7 @@ def __init__( self.command_failed: bool = command_failed async def invoke(self, command: Invokable[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: - r"""|coro| + """|coro| Calls a command with the arguments given. @@ -199,7 +246,7 @@ async def _get_channel(self) -> abc.Messageable: @property def source(self) -> Union[Message, Interaction]: - """Union[:class:`Message`, :class:`Interaction`]: Property to return a message or interaction + """Union[:class:`.Message`, :class:`.Interaction`]: Property to return a message or interaction depending on the context. """ raise NotImplementedError() @@ -230,9 +277,7 @@ def guild_id(self) -> Optional[int]: @utils.cached_property def channel(self) -> MessageableChannel: - """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. - Shorthand for :attr:`.Message.channel`. - """ + """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.""" return self.source.channel @utils.cached_property @@ -242,9 +287,7 @@ def channel_id(self) -> Optional[int]: @utils.cached_property def author(self) -> Union[User, Member]: - """Union[:class:`.User`, :class:`.Member`]: - Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` - """ + """Union[:class:`.User`, :class:`.Member`]: Returns the author associated with this context's command.""" return self.source.author @property @@ -256,7 +299,7 @@ def user(self) -> Union[User, Member]: def me(self) -> Union[Member, ClientUser]: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message - message contexts, or when :meth:`Intents.guilds` is absent. + message contexts, or when :meth:`.Intents.guilds` is absent. """ # bot.user will never be None at this point. return self.guild.me if self.guild and self.guild.me else self.bot.user # type: ignore @@ -268,6 +311,40 @@ def voice_client(self) -> Optional[VoiceProtocol]: class Invokable(Generic[CogT, P, T]): + """A baseclass to provide ***basic & common functionality*** between + :class:`.ApplicationCommand` and :class:`~ext.commands.Command`. + + .. versionadded:: 2.2 + + Attributes + ---------- + name: str + The name of the invokable/command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + parent: Optional[:class:`Invokable`] + The parent group of this command. + cog: Optional[:class:`Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + enabled: :class:`bool` + A boolean that indicates if the command is currently enabled. + If the command is invoked while it is disabled, then + :exc:`.DisabledCommand` is raised to the :func:`.on_command_error` + event. Defaults to ``True``. + checks: List[Callable[[:class:`.BaseContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed with the given + :class:`.BaseContext` (:class:`.ApplicationContext` or :class:`~ext.commands.Context` + to be specific) as the sole parameter. If an exception is necessary to be thrown to + signal failure, then one inherited from :exc:`.CommandError` should be used. Note that + if the checks fail then :exc:`.CheckFailure` exception is raised to the :func:`.on_command_error` + event. + cooldown_after_parsing: :class:`bool` + If ``True``\, cooldown processing is done after argument parsing, + which calls converters. If ``False`` then cooldown processing is done + first and then the converters are called second. Defaults to ``False``. + cooldown: Optional[:class:`Cooldown`] + The cooldown applied when the command is invoked. + """ def __init__(self, func: CallbackT, **kwargs): self.callback: CallbackT = func self.parent: Optional[Invokable] = parent if isinstance((parent := kwargs.get("parent")), _BaseCommand) else None @@ -328,6 +405,24 @@ def callback(self, func: CallbackT) -> None: def cooldown(self): return self._buckets._cooldown + @property + def qualified_name(self) -> str: + """:class:`str`: Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``?one two three`` the qualified name would be + ``one two three``. + """ + if not self.parent: + return self.name + + return f"{self.parent.qualified_name} {self.name}" + + @property + def cog_name(self) -> Optional[str]: + """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" + return type(self.cog).__cog_name__ if self.cog is not None else None + @property def parents(self) -> List[Invokable]: """List[:class:`Invokable`]: Retrieves the parents of this command. @@ -366,24 +461,6 @@ def full_parent_name(self) -> Optional[str]: if self.parent: return self.parent.qualified_name - @property - def qualified_name(self) -> str: - """:class:`str`: Retrieves the fully qualified command name. - - This is the full parent name with the command name as well. - For example, in ``?one two three`` the qualified name would be - ``one two three``. - """ - if not self.parent: - return self.name - - return f"{self.parent.qualified_name} {self.name}" - - @property - def cog_name(self) -> Optional[str]: - """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" - return type(self.cog).__cog_name__ if self.cog is not None else None - def __str__(self) -> str: return self.qualified_name @@ -549,31 +626,6 @@ async def can_run(self, ctx: ContextT) -> bool: finally: ctx.command = original - async def _dispatch_error(self, ctx: ContextT, error: Exception) -> None: - # since I don't want to copy paste code, subclassed Contexts - # dispatch it to their corresponding events - raise NotImplementedError() - - async def dispatch_error(self, ctx: ContextT, error: Exception) -> None: - ctx.command_failed = True - cog = self.cog - - if coro := getattr(self, "on_error", None): - injected = wrap_callback(coro) - if cog is not None: - await injected(cog, ctx, error) - else: - await injected(ctx, error) - - try: - if cog is not None: - local = cog.__class__._get_overridden_method(cog.cog_command_error) - if local is not None: - wrapped = wrap_callback(local) - await wrapped(ctx, error) - finally: - await self._dispatch_error(ctx, error) - def add_check(self, func: Check) -> None: """Adds a check to the command. @@ -604,18 +656,45 @@ def remove_check(self, func: Check) -> None: except ValueError: pass - def _prepare_cooldowns(self, ctx: ContextT): - if not self._buckets.valid: - return + def copy(self): + """Creates a copy of this command. - current = datetime.datetime.now().timestamp() - bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message + Returns + -------- + :class:`Invokable` + A new instance of this command. + """ + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) - if bucket: - retry_after = bucket.update_rate_limit(current) + def _ensure_assignment_on_copy(self, other: Invokable): + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + if self._buckets.valid and not other._buckets.valid: + other._buckets = self._buckets.copy() + if self._max_concurrency != other._max_concurrency: + # _max_concurrency won't be None at this point + other._max_concurrency = self._max_concurrency.copy() # type: ignore - if retry_after: - raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + try: + other.on_error = self.on_error + except AttributeError: + pass + return other + + def _update_copy(self, kwargs: Dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() + + def _set_cog(self, cog: CogT): + self.cog = cog def is_on_cooldown(self, ctx: ContextT) -> bool: """Checks whether the command is currently on cooldown. @@ -678,6 +757,19 @@ def get_cooldown_retry_after(self, ctx) -> float: return 0.0 + def _prepare_cooldowns(self, ctx: ContextT): + if not self._buckets.valid: + return + + current = datetime.datetime.now().timestamp() + bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message + + if bucket: + retry_after = bucket.update_rate_limit(current) + + if retry_after: + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + async def call_before_hooks(self, ctx: ContextT) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: @@ -714,7 +806,6 @@ async def call_after_hooks(self, ctx: ContextT) -> None: # call the cog local hook if applicable: if cog is not None: - # :troll: hook = cog.__class__._get_overridden_method(cog.cog_after_invoke) if hook is not None: await hook(ctx) @@ -752,6 +843,13 @@ async def prepare(self, ctx: ContextT) -> None: raise async def invoke(self, ctx: ContextT) -> None: + """Runs the command with checks. + + Parameters + ---------- + ctx: :class:`.BaseContext` + The context to pass into the command. + """ await self.prepare(ctx) # terminate the invoked_subcommand chain. @@ -779,42 +877,27 @@ async def reinvoke(self, ctx: ContextT, *, call_hooks: bool = False) -> None: if call_hooks: await self.call_after_hooks(ctx) - def copy(self): - """Creates a copy of this command. + async def _dispatch_error(self, ctx: ContextT, error: Exception) -> None: + # since I don't want to copy paste code, subclassed Contexts + # dispatch it to their corresponding events + raise NotImplementedError() - Returns - -------- - :class:`Invokable` - A new instance of this command. - """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) + async def dispatch_error(self, ctx: ContextT, error: Exception) -> None: + ctx.command_failed = True + cog = self.cog - def _ensure_assignment_on_copy(self, other: Invokable): - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - if self._buckets.valid and not other._buckets.valid: - other._buckets = self._buckets.copy() - if self._max_concurrency != other._max_concurrency: - # _max_concurrency won't be None at this point - other._max_concurrency = self._max_concurrency.copy() # type: ignore + if coro := getattr(self, "on_error", None): + injected = wrap_callback(coro) + if cog is not None: + await injected(cog, ctx, error) + else: + await injected(ctx, error) try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def _update_copy(self, kwargs: Dict[str, Any]): - if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() - - def _set_cog(self, cog: CogT): - self.cog = cog + if cog is not None: + local = cog.__class__._get_overridden_method(cog.cog_command_error) + if local is not None: + wrapped = wrap_callback(local) + await wrapped(ctx, error) + finally: + await self._dispatch_error(ctx, error) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 5d0914bdb4..ac41ea3dd8 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -31,8 +31,7 @@ import discord.abc import discord.utils from discord.message import Message -from ...commands.mixins import BaseContext -from ...commands.context import ApplicationContext +from ...commands import BaseContext, ApplicationContext if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -70,16 +69,6 @@ class Context(BaseContext, Generic[BotT]): ----------- message: :class:`.Message` The message that triggered the command being executed. - bot: :class:`.Bot` - The bot that contains the command being executed. - args: :class:`list` - The list of transformed arguments that were passed into the command. - If this is accessed during the :func:`.on_command_error` event - then this list could be incomplete. - kwargs: :class:`dict` - A dictionary of transformed arguments that were passed into the command. - Similar to :attr:`args`\, if this is accessed in the - :func:`.on_command_error` event then this dict could be incomplete. current_parameter: Optional[:class:`inspect.Parameter`] The parameter that is currently being inspected and converted. This is only of use for within converters. @@ -87,27 +76,6 @@ class Context(BaseContext, Generic[BotT]): .. versionadded:: 2.0 prefix: Optional[:class:`str`] The prefix that was used to invoke the command. - command: Optional[:class:`Command`] - The command that is being invoked currently. - invoked_with: Optional[:class:`str`] - The command name that triggered this invocation. Useful for finding out - which alias called the command. - invoked_parents: List[:class:`str`] - The command names of the parents that triggered this invocation. Useful for - finding out which aliases called the command. - - For example in commands ``?a b c test``, the invoked parents are ``['a', 'b', 'c']``. - - .. versionadded:: 1.7 - - invoked_subcommand: Optional[:class:`Command`] - The subcommand that was invoked. - If no valid subcommand was invoked then this is equal to ``None``. - subcommand_passed: Optional[:class:`str`] - The string that was attempted to call a subcommand. This does not have - to point to a valid registered subcommand and could just point to a - nonsense string. If nothing was passed to attempt a call to a - subcommand then this is set to ``None``. command_failed: :class:`bool` A boolean that indicates if the command failed to be parsed, checked, or invoked. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 789aa8c6d0..ed0471c499 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -59,11 +59,9 @@ Invokable, hooked_wrapped_callback, unwrap_function, - wrap_callback, ) from ...enums import ChannelType from ...errors import * -from .cog import Cog from .context import Context from .converter import Greedy, get_converter, run_converters from .cooldowns import ( @@ -184,17 +182,15 @@ def __setitem__(self, k, v): class Command(Invokable, _BaseCommand, Generic[CogT, P, T]): - r"""A class that implements the protocol for a bot text command. + """A class that implements the protocol for a bot text command. These are not created manually, instead they are created via the decorator or functional interface. + This is a subclass of :class:`Invokable`. + Attributes ----------- - name: :class:`str` - The name of the command. - callback: :ref:`coroutine ` - The coroutine that is executed when the command is called. help: Optional[:class:`str`] The long help text for the command. brief: Optional[:class:`str`] @@ -203,23 +199,8 @@ class Command(Invokable, _BaseCommand, Generic[CogT, P, T]): A replacement for arguments in the default help text. aliases: Union[List[:class:`str`], Tuple[:class:`str`]] The list of aliases the command can be invoked under. - enabled: :class:`bool` - A boolean that indicates if the command is currently enabled. - If the command is invoked while it is disabled, then - :exc:`.DisabledCommand` is raised to the :func:`.on_command_error` - event. Defaults to ``True``. parent: Optional[:class:`Group`] - The parent group that this command belongs to. ``None`` if there - isn't one. - cog: Optional[:class:`Cog`] - The cog that this command belongs to. ``None`` if there isn't one. - checks: List[Callable[[:class:`.Context`], :class:`bool`]] - A list of predicates that verifies if the command could be executed - with the given :class:`.Context` as the sole parameter. If an exception - is necessary to be thrown to signal failure, then one inherited from - :exc:`.CommandError` should be used. Note that if the checks fail then - :exc:`.CheckFailure` exception is raised to the :func:`.on_command_error` - event. + The parent group that this command belongs to. description: :class:`str` The message prefixed into the default help command. hidden: :class:`bool` @@ -232,8 +213,6 @@ class Command(Invokable, _BaseCommand, Generic[CogT, P, T]): regular matter rather than passing the rest completely raw. If ``True`` then the keyword-only argument will pass in the rest of the arguments in a completely raw matter. Defaults to ``False``. - invoked_subcommand: Optional[:class:`Command`] - The subcommand that was invoked, if any. require_var_positional: :class:`bool` If ``True`` and a variadic positional argument is specified, requires the user to specify at least one argument. Defaults to ``False``. @@ -245,23 +224,12 @@ class Command(Invokable, _BaseCommand, Generic[CogT, P, T]): requirements are met (e.g. ``?foo a b c`` when only expecting ``a`` and ``b``). Otherwise :func:`.on_command_error` and local error handlers are called with :exc:`.TooManyArguments`. Defaults to ``True``. - cooldown_after_parsing: :class:`bool` - If ``True``\, cooldown processing is done after argument parsing, - which calls converters. If ``False`` then cooldown processing is done - first and then the converters are called second. Defaults to ``False``. extras: :class:`dict` A dict of user provided extras to attach to the Command. .. note:: This object may be copied by the library. - - .. versionadded:: 2.0 - - cooldown: Optional[:class:`Cooldown`] - The cooldown applied when the command is invoked. ``None`` if the command - doesn't have a cooldown. - .. versionadded:: 2.0 """ __original_kwargs__: Dict[str, Any] @@ -572,7 +540,6 @@ def signature(self) -> str: def _set_cog(self, cog): self.cog = cog -# TODO: This is a mess class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. diff --git a/docs/api.rst b/docs/api.rst index 8a00a048fe..e4fa1d03b9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -93,20 +93,49 @@ AutoShardedBot .. autoclass:: AutoShardedBot :members: +Commands +--------- + +Invokable +~~~~~~~~~~ +.. attributetable:: discord.commands.Invokable + +.. autoclass:: discord.commands.Invokable + :members: + +BaseContext +~~~~~~~~~~ +.. attributetable:: discord.commands.BaseContext + +.. autoclass:: discord.commands.BaseContext + :members: + Application Commands --------------------- - -Command Permission Decorators +Decorators ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - .. autofunction:: discord.commands.default_permissions :decorator: .. autofunction:: discord.commands.guild_only :decorator: +.. autofunction:: discord.commands.application_command + :decorator: + +.. autofunction:: discord.commands.command + :decorator: + +.. autofunction:: discord.commands.slash_command + :decorator: + +.. autofunction:: discord.commands.user_command + :decorator: + +.. autofunction:: discord.commands.message_command + :decorator: ApplicationCommand ~~~~~~~~~~~~~~~~~~~ @@ -115,12 +144,6 @@ ApplicationCommand .. autoclass:: ApplicationCommand :members: - -.. autofunction:: discord.commands.application_command - :decorator: - -.. autofunction:: discord.commands.command - :decorator: SlashCommand ~~~~~~~~~~~~~ @@ -129,9 +152,6 @@ SlashCommand .. autoclass:: SlashCommand :members: - -.. autofunction:: discord.commands.slash_command - :decorator: SlashCommandGroup ~~~~~~~~~~~~~~~~~~ @@ -141,6 +161,30 @@ SlashCommandGroup .. autoclass:: SlashCommandGroup :members: +ContextMenuCommand +~~~~~~~~~~~~ + +.. attributetable:: ContextMenuCommand + +.. autoclass:: ContextMenuCommand + :members: + +UserCommand +~~~~~~~~~~~~ + +.. attributetable:: UserCommand + +.. autoclass:: UserCommand + :members: + +MessageCommand +~~~~~~~~~~~~~~~ + +.. attributetable:: MessageCommand + +.. autoclass:: MessageCommand + :members: + Option ~~~~~~~ @@ -148,7 +192,7 @@ Option .. autoclass:: Option :members: - + .. autofunction:: discord.commands.option :decorator: @@ -168,28 +212,6 @@ OptionChoice .. autoclass:: OptionChoice :members: -UserCommand -~~~~~~~~~~~~ - -.. attributetable:: UserCommand - -.. autoclass:: UserCommand - :members: - -.. autofunction:: discord.commands.user_command - :decorator: - -MessageCommand -~~~~~~~~~~~~~~~ - -.. attributetable:: MessageCommand - -.. autoclass:: MessageCommand - :members: - -.. autofunction:: discord.commands.message_command - :decorator: - ApplicationContext ~~~~~~~~~~~~~~~~~~~ From b4aed1cd29b97607dd8e442ae3f22301d89a05ec Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 16:09:48 -0400 Subject: [PATCH 18/54] Merge from master --- discord/bot.py | 2 +- discord/commands/core.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index 514944a257..0b4680b310 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -133,7 +133,7 @@ def add_application_command(self, command: ApplicationCommand) -> None: raise TypeError("The provided command is a sub-command of group") if command.cog is MISSING: - command.cog = None + command._set_cog(None) if self._bot.debug_guilds and command.guild_ids is None: command.guild_ids = self._bot.debug_guilds diff --git a/discord/commands/core.py b/discord/commands/core.py index a813b6cbb4..9cf2eaaa8c 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -585,16 +585,10 @@ def to_dict(self) -> Dict: return as_dict - def add_command(self, command: SlashCommand) -> None: - if command.cog is MISSING: - command.cog = self.cog - - self.subcommands.append(command) - def command(self, cls: Type[T] = SlashCommand, **kwargs) -> Callable[[Callable], SlashCommand]: def wrap(func) -> T: command = cls(func, parent=self, **kwargs) - self.add_command(command) + self.subcommands.append(command) return command return wrap From cffb506699710b3d6c676c257614eb85f8453d47 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 16:53:11 -0400 Subject: [PATCH 19/54] revert: the unfix --- discord/bot.py | 2 +- discord/commands/core.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index 0b4680b310..514944a257 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -133,7 +133,7 @@ def add_application_command(self, command: ApplicationCommand) -> None: raise TypeError("The provided command is a sub-command of group") if command.cog is MISSING: - command._set_cog(None) + command.cog = None if self._bot.debug_guilds and command.guild_ids is None: command.guild_ids = self._bot.debug_guilds diff --git a/discord/commands/core.py b/discord/commands/core.py index 9cf2eaaa8c..ed3dbb5d5e 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -473,6 +473,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): return await ctx.interaction.response.send_autocomplete_result(choices=choices) +# TODO: implement with GrouMixin maybe class SlashCommandGroup(ApplicationCommand): """A class that implements the protocol for a slash command group. @@ -585,10 +586,16 @@ def to_dict(self) -> Dict: return as_dict + def add_command(self, command: SlashCommand) -> None: + if command.cog is MISSING: + command.cog = self.cog + + self.subcommands.append(command) + def command(self, cls: Type[T] = SlashCommand, **kwargs) -> Callable[[Callable], SlashCommand]: def wrap(func) -> T: command = cls(func, parent=self, **kwargs) - self.subcommands.append(command) + self.add_command(command) return command return wrap From ae466b5fceb5dace1eb7e3a4ec6fa039b6f82e1d Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 17:12:35 -0400 Subject: [PATCH 20/54] refactor: remove parent checks reference --- discord/commands/mixins.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index ee3a96439a..3bd2679322 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -604,14 +604,6 @@ async def can_run(self, ctx: ContextT) -> bool: if not await ctx.bot.can_run(ctx): raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") - # I personally don't think parent checks should be - # run with the subcommand. It causes confusion, and - # nerfs control for a bit of reduced redundancy - # predicates = self.checks - # if self.parent is not None: - # # parent checks should be run first - # predicates = self.parent.checks + predicates - if (cog := self.cog) and (local_check := cog._get_overridden_method(cog.cog_check)): ret = await utils.maybe_coroutine(local_check, ctx) if not ret: From 8751a2284e83d9ee94940fc211e34fecf089dfbd Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 17:25:35 -0400 Subject: [PATCH 21/54] fix: p* --- discord/commands/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index ed3dbb5d5e..ab2fac81df 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -473,7 +473,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): return await ctx.interaction.response.send_autocomplete_result(choices=choices) -# TODO: implement with GrouMixin maybe +# TODO: implement with GroupMixin maybe class SlashCommandGroup(ApplicationCommand): """A class that implements the protocol for a slash command group. From 2f3d00bfd2ec79c2d766f5432f5016d8495fd295 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 17:33:30 -0400 Subject: [PATCH 22/54] refactor: remove redundant defintions --- discord/commands/core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index ab2fac81df..eec810b3e9 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -79,7 +79,6 @@ from typing_extensions import ParamSpec from .. import Permissions - from ..cog import Cog from .mixins import BaseContext P = ParamSpec("P") @@ -88,8 +87,6 @@ T = TypeVar("T") -CogT = TypeVar("CogT", bound="Cog") -Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) def _validate_names(obj): From 731864e20499a8f74d7902111c218225ce7bad10 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 17:54:39 -0400 Subject: [PATCH 23/54] fix: error hierarchy and docs --- discord/errors.py | 197 ++++++++++++++++++--------------- discord/ext/commands/errors.py | 13 +-- docs/api.rst | 34 +++++- docs/ext/commands/api.rst | 32 ++---- 4 files changed, 145 insertions(+), 131 deletions(-) diff --git a/discord/errors.py b/discord/errors.py index 4c5479bdfb..9a05ca9723 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -62,9 +62,15 @@ "NoEntryPointError", "ExtensionFailed", "ExtensionNotFound", + "CommandError", + "CommandInvokeError", "ApplicationCommandError", - "CheckFailure", "ApplicationCommandInvokeError", + "CheckFailure", + "MaxConcurrencyReached", + "CommandOnCooldown", + "DisabledCommand", + "UserInputError", ) @@ -106,24 +112,6 @@ class ValidationError(DiscordException): pass -def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]: - items: List[Tuple[str, str]] = [] - for k, v in d.items(): - new_key = f"{key}.{k}" if key else k - - if isinstance(v, dict): - try: - _errors: List[Dict[str, Any]] = v["_errors"] - except KeyError: - items.extend(_flatten_error_dict(v, new_key).items()) - else: - items.append((new_key, " ".join(x.get("message", "") for x in _errors))) - else: - items.append((new_key, v)) - - return dict(items) - - class HTTPException(DiscordException): """Exception that's raised when an HTTP request operation fails. @@ -336,13 +324,21 @@ class ApplicationCommandError(CommandError): pass -class CheckFailure(CommandError): - """Exception raised when the predicates in :attr:`.Command.checks` have failed. +class CommandInvokeError(CommandError): + """Exception raised when the command being invoked raised an exception. This inherits from :exc:`CommandError` + + Attributes + ----------- + original: :exc:`Exception` + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. """ - pass + def __init__(self, e: Exception) -> None: + self.original: Exception = e + super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") class ApplicationCommandInvokeError(ApplicationCommandError): @@ -362,6 +358,81 @@ def __init__(self, e: Exception) -> None: super().__init__(f"Application Command raised an exception: {e.__class__.__name__}: {e}") +class CheckFailure(CommandError): + """Exception raised when the predicates in :attr:`.Command.checks` have failed. + + This inherits from :exc:`CommandError` + """ + + pass + + +class MaxConcurrencyReached(CommandError): + """Exception raised when the command being invoked has reached its maximum concurrency. + + This inherits from :exc:`CommandError`. + + Attributes + ------------ + number: :class:`int` + The maximum number of concurrent invokers allowed. + per: :class:`.BucketType` + The bucket type passed to the :func:`.max_concurrency` decorator. + """ + + def __init__(self, number: int, per: BucketType) -> None: + self.number: int = number + self.per: BucketType = per + name = per.name + suffix = f"per {name}" if per.name != "default" else "globally" + plural = "%s times %s" if number > 1 else "%s time %s" + fmt = plural % (number, suffix) + super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") + + +class CommandOnCooldown(CommandError): + """Exception raised when the command being invoked is on cooldown. + + This inherits from :exc:`CommandError` + + Attributes + ----------- + cooldown: :class:`.Cooldown` + A class with attributes ``rate`` and ``per`` similar to the + :func:`.cooldown` decorator. + type: :class:`BucketType` + The type associated with the cooldown. + retry_after: :class:`float` + The amount of seconds to wait before you can retry again. + """ + + def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: + self.cooldown: Cooldown = cooldown + self.retry_after: float = retry_after + self.type: BucketType = type + super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s") + + +class DisabledCommand(CommandError): + """Exception raised when the command being invoked is disabled. + + This inherits from :exc:`CommandError` + """ + + pass + + +class UserInputError(CommandError): + """The base exception type for errors that involve errors + regarding user input. + + This inherits from :exc:`CommandError`. + """ + + pass + + + class ExtensionError(DiscordException): """Base exception for extension related errors. @@ -450,73 +521,19 @@ def __init__(self, name: str) -> None: super().__init__(msg, name=name) -class MaxConcurrencyReached(CommandError): - """Exception raised when the command being invoked has reached its maximum concurrency. - - This inherits from :exc:`CommandError`. - - Attributes - ------------ - number: :class:`int` - The maximum number of concurrent invokers allowed. - per: :class:`.BucketType` - The bucket type passed to the :func:`.max_concurrency` decorator. - """ - - def __init__(self, number: int, per: BucketType) -> None: - self.number: int = number - self.per: BucketType = per - name = per.name - suffix = f"per {name}" if per.name != "default" else "globally" - plural = "%s times %s" if number > 1 else "%s time %s" - fmt = plural % (number, suffix) - super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") - - -class CommandOnCooldown(CommandError): - """Exception raised when the command being invoked is on cooldown. - - This inherits from :exc:`CommandError` - - Attributes - ----------- - cooldown: :class:`.Cooldown` - A class with attributes ``rate`` and ``per`` similar to the - :func:`.cooldown` decorator. - type: :class:`BucketType` - The type associated with the cooldown. - retry_after: :class:`float` - The amount of seconds to wait before you can retry again. - """ - - def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: - self.cooldown: Cooldown = cooldown - self.retry_after: float = retry_after - self.type: BucketType = type - super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s") - - -class CommandInvokeError(CommandError): - """Exception raised when the command being invoked raised an exception. - - This inherits from :exc:`CommandError` - - Attributes - ----------- - original: :exc:`Exception` - The original exception that was raised. You can also get this via - the ``__cause__`` attribute. - """ - - def __init__(self, e: Exception) -> None: - self.original: Exception = e - super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") - - -class DisabledCommand(CommandError): - """Exception raised when the command being invoked is disabled. +def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]: + items: List[Tuple[str, str]] = [] + for k, v in d.items(): + new_key = f"{key}.{k}" if key else k - This inherits from :exc:`CommandError` - """ + if isinstance(v, dict): + try: + _errors: List[Dict[str, Any]] = v["_errors"] + except KeyError: + items.extend(_flatten_error_dict(v, new_key).items()) + else: + items.append((new_key, " ".join(x.get("message", "") for x in _errors))) + else: + items.append((new_key, v)) - pass + return dict(items) diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 13db7cac21..e6d98fee2d 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -34,7 +34,8 @@ CommandOnCooldown, CommandInvokeError, MaxConcurrencyReached, - DisabledCommand + DisabledCommand, + UserInputError ) if TYPE_CHECKING: @@ -123,16 +124,6 @@ def __init__(self, converter: Converter, original: Exception) -> None: self.original: Exception = original -class UserInputError(CommandError): - """The base exception type for errors that involve errors - regarding user input. - - This inherits from :exc:`CommandError`. - """ - - pass - - class CommandNotFound(CommandError): """Exception raised when a command is attempted to be invoked but no command under that name is found. diff --git a/docs/api.rst b/docs/api.rst index e4fa1d03b9..00c8c8630d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -5163,13 +5163,31 @@ The following exceptions are thrown by the library. .. autoexception:: discord.opus.OpusNotLoaded +.. autoexception:: discord.CommandError + :members: + .. autoexception:: discord.ApplicationCommandError :members: - + +.. autoexception:: discord.ApplicationCommandInvokeError + :members: + .. autoexception:: discord.CheckFailure :members: - -.. autoexception:: discord.ApplicationCommandInvokeError + +.. autoexception:: discord.DisabledCommand + :members: + +.. autoexception:: discord.CommandInvokeError + :members: + +.. autoexception:: discord.UserInputError + :members: + +.. autoexception:: discord.CommandOnCooldown + :members: + +.. autoexception:: discord.MaxConcurrencyReached :members: .. autoexception:: discord.ExtensionError @@ -5210,9 +5228,15 @@ Exception Hierarchy - :exc:`Forbidden` - :exc:`NotFound` - :exc:`DiscordServerError` - - :exc:`ApplicationCommandError` + - :exc:`CommandError` + - :exc:`ApplicationCommandError` + - :exc:`ApplicationCommandInvokeError` + - :exc:`CommandInvokeError` - :exc:`CheckFailure` - - :exc:`ApplicationCommandInvokeError` + - :exc:`MaxConcurrencyReached` + - :exc:`CommandOnCooldown` + - :exc:`DisabledCommand` + - :exc:`UserInputError` - :exc:`ExtensionError` - :exc:`ExtensionAlreadyLoaded` - :exc:`ExtensionNotLoaded` diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 8d3b4f7c24..717cf24c86 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -486,33 +486,15 @@ Exceptions .. autoexception:: discord.ext.commands.NoPrivateMessage :members: -.. autoexception:: discord.ext.commands.CheckFailure - :members: - .. autoexception:: discord.ext.commands.CheckAnyFailure :members: .. autoexception:: discord.ext.commands.CommandNotFound :members: -.. autoexception:: discord.ext.commands.DisabledCommand - :members: - -.. autoexception:: discord.ext.commands.CommandInvokeError - :members: - .. autoexception:: discord.ext.commands.TooManyArguments :members: -.. autoexception:: discord.ext.commands.UserInputError - :members: - -.. autoexception:: discord.ext.commands.CommandOnCooldown - :members: - -.. autoexception:: discord.ext.commands.MaxConcurrencyReached - :members: - .. autoexception:: discord.ext.commands.NotOwner :members: @@ -604,9 +586,9 @@ Exception Hierarchy .. exception_hierarchy:: - :exc:`~.DiscordException` - - :exc:`~.commands.CommandError` + - :exc:`~.CommandError` - :exc:`~.commands.ConversionError` - - :exc:`~.commands.UserInputError` + - :exc:`~.UserInputError` - :exc:`~.commands.MissingRequiredArgument` - :exc:`~.commands.TooManyArguments` - :exc:`~.commands.BadArgument` @@ -636,7 +618,7 @@ Exception Hierarchy - :exc:`~.commands.InvalidEndOfQuotedStringError` - :exc:`~.commands.ExpectedClosingQuoteError` - :exc:`~.commands.CommandNotFound` - - :exc:`~.commands.CheckFailure` + - :exc:`~.CheckFailure` - :exc:`~.commands.CheckAnyFailure` - :exc:`~.commands.PrivateMessageOnly` - :exc:`~.commands.NoPrivateMessage` @@ -648,9 +630,9 @@ Exception Hierarchy - :exc:`~.commands.MissingAnyRole` - :exc:`~.commands.BotMissingAnyRole` - :exc:`~.commands.NSFWChannelRequired` - - :exc:`~.commands.DisabledCommand` - - :exc:`~.commands.CommandInvokeError` - - :exc:`~.commands.CommandOnCooldown` - - :exc:`~.commands.MaxConcurrencyReached` + - :exc:`~.DisabledCommand` + - :exc:`~.CommandInvokeError` + - :exc:`~.CommandOnCooldown` + - :exc:`~.MaxConcurrencyReached` - :exc:`~.ClientException` - :exc:`~.commands.CommandRegistrationError` From 23bdc509f71ed7f705bc0f338565c24bfe335c9c Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 29 Aug 2022 19:04:15 -0400 Subject: [PATCH 24/54] Smaller fixes --- discord/commands/core.py | 4 ++-- discord/ext/commands/core.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index eec810b3e9..1fd974fb20 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -106,8 +106,8 @@ def _validate_descriptions(obj): class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): """Base class for all Application Commands, including: - - :class:`.SlashCommand` - - :class:`.SlashCommandGroup` + - :class:`SlashCommand` + - :class:`SlashCommandGroup` - :class:`ContextMenuCommand` which in turn is a superclass of - :class:`MessageCommand` and - :class:`UserCommand` diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index ed0471c499..71e2d61165 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -537,8 +537,6 @@ def signature(self) -> str: return " ".join(result) - def _set_cog(self, cog): - self.cog = cog class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave From 5230946dcf150edec44630356aa38b3f05550244 Mon Sep 17 00:00:00 2001 From: Middledot Date: Tue, 30 Aug 2022 15:00:37 -0400 Subject: [PATCH 25/54] Uncommit unrelated things --- discord/ext/commands/bot.py | 4 +++- discord/ext/commands/context.py | 1 + discord/ext/commands/converter.py | 2 ++ discord/webhook/sync.py | 2 ++ 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 75b82f32b6..c3af75c236 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -40,9 +40,11 @@ from .view import StringView if TYPE_CHECKING: + import importlib.machinery + from discord.message import Message - from ._types import CoroFunc + from ._types import Check, CoroFunc __all__ = ( "when_mentioned", diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index ac41ea3dd8..3ca95d269e 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -40,6 +40,7 @@ from .cog import Cog from .core import Command from .view import StringView + from .help import HelpCommand __all__ = ("Context",) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index fc5e795942..5ec807f417 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -464,6 +464,8 @@ def check(c): @staticmethod def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: + bot = ctx.bot + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index e5ff135d50..459c174209 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -45,6 +45,8 @@ Literal, Optional, Tuple, + Type, + TypeVar, Union, overload, ) From 62077e8b774780577ed77d82cafe58aa661de834 Mon Sep 17 00:00:00 2001 From: Middledot Date: Tue, 30 Aug 2022 15:06:57 -0400 Subject: [PATCH 26/54] Make more concise imports --- discord/commands/__init__.py | 2 +- discord/commands/mixins.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py index 7a91ea8f15..21e54de029 100644 --- a/discord/commands/__init__.py +++ b/discord/commands/__init__.py @@ -27,4 +27,4 @@ from .core import * from .options import * from .permissions import * -from .mixins import Invokable, _BaseCommand, BaseContext +from .mixins import * diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 3bd2679322..1073d6d87b 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -72,6 +72,13 @@ HookT = TypeVar("HookT", bound="Hook") +__all__ = ( + "Invokable", + "_BaseCommand", + "BaseContext", +) + + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: partial = functools.partial while True: From 048d5e2f99b4a9e82157c5d30ee05fbb17d8bffb Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 1 Oct 2022 11:05:04 -0400 Subject: [PATCH 27/54] fix: more cog-setting errors --- discord/commands/core.py | 14 +++++++------- discord/commands/mixins.py | 2 +- discord/ext/commands/core.py | 19 +++++-------------- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index 1fd974fb20..0e2d356b45 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -35,7 +35,6 @@ TYPE_CHECKING, Any, Callable, - Coroutine, Dict, Generator, Generic, @@ -200,8 +199,6 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self.options: List[Option] = kwargs.get("options", []) - self._cog = MISSING - def _validate_parameters(self): params = self._get_signature_parameters() if kwop := self.options: @@ -268,7 +265,7 @@ def _parse_options(self, params, *, check_params: bool = True) -> List[Option]: return final_options - def _match_option_param_names(self, params, options): + def _match_option_param_names(self, params, options: List[Option]): params = self._check_required_params(params) check_annotations: List[Callable[[Option, Type], bool]] = [ @@ -313,8 +310,11 @@ def cog(self): @cog.setter def cog(self, val): - self._cog = val - self._validate_parameters() + if not hasattr(self, "_cog"): + self._cog = MISSING + else: + self._cog = val + self._validate_parameters() @property def is_subcommand(self) -> bool: @@ -584,7 +584,7 @@ def to_dict(self) -> Dict: return as_dict def add_command(self, command: SlashCommand) -> None: - if command.cog is MISSING: + if command.cog is MISSING and not self.cog is None: command.cog = self.cog self.subcommands.append(command) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 1073d6d87b..0dba9b0ac8 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -355,7 +355,7 @@ class Invokable(Generic[CogT, P, T]): def __init__(self, func: CallbackT, **kwargs): self.callback: CallbackT = func self.parent: Optional[Invokable] = parent if isinstance((parent := kwargs.get("parent")), _BaseCommand) else None - self.cog: Optional[CogT] + self.cog: Optional[CogT] = None self.module: Any = None self.name: str = str(kwargs.get("name", func.__name__)) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index f7e4f7762e..dd7f7fa10c 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -582,7 +582,7 @@ def add_command(self, command: Command[CogT, Any, Any]) -> None: :meth:`~.GroupMixin.group` shortcut decorators are used instead. .. versionchanged:: 1.4 - Raise :exc:`.CommandRegistrationError` instead of generic :exc:`.ClientException` + Raise :exc:`.CommandRegistrationError` instead of generic :exc:`.ClientException` Parameters ----------- @@ -950,7 +950,6 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: @overload # for py 3.10 def command( - name: str = ..., cls: Type[Command[CogT, P, T]] = ..., **attrs: Any, ) -> Callable[ @@ -967,7 +966,6 @@ def command( @overload def command( - name: str = ..., cls: Type[Command[CogT, P, T]] = ..., **attrs: Any, ) -> Callable[ @@ -984,7 +982,6 @@ def command( @overload def command( - name: str = ..., cls: Type[CommandT] = ..., **attrs: Any, ) -> Callable[ @@ -1000,7 +997,7 @@ def command( def command( - name: str = MISSING, cls: Type[CommandT] = MISSING, **attrs: Any + cls: Type[CommandT] = MISSING, **attrs: Any ) -> Callable[ [ Union[ @@ -1024,13 +1021,10 @@ def command( Parameters ----------- - name: :class:`str` - The name to create the command with. By default, this uses the - function name unchanged. cls The class to construct with. By default, this is :class:`.Command`. You usually do not change this. - attrs + \*\*attrs Keyword arguments to pass into the construction of the class denoted by ``cls``. @@ -1050,14 +1044,13 @@ def decorator( ) -> CommandT: if isinstance(func, Command): raise TypeError("Callback is already a command.") - return cls(func, name=name, **attrs) + return cls(func, **attrs) return decorator @overload def group( - name: str = ..., cls: Type[Group[CogT, P, T]] = ..., **attrs: Any, ) -> Callable[ @@ -1074,7 +1067,6 @@ def group( @overload def group( - name: str = ..., cls: Type[GroupT] = ..., **attrs: Any, ) -> Callable[ @@ -1090,7 +1082,6 @@ def group( def group( - name: str = MISSING, cls: Type[GroupT] = MISSING, **attrs: Any, ) -> Callable[ @@ -1112,7 +1103,7 @@ def group( """ if cls is MISSING: cls = Group # type: ignore - return command(name=name, cls=cls, **attrs) # type: ignore + return command(cls=cls, **attrs) # type: ignore def check(predicate: Check) -> Callable[[T], T]: From 098fc0b21fde2752d9cea62b31778e334cb04d75 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 1 Oct 2022 11:16:10 -0400 Subject: [PATCH 28/54] chore: use more generic cog typehints --- discord/cog.py | 47 ++++++++++++++++++++----------------- discord/ext/commands/cog.py | 10 ++++---- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/discord/cog.py b/discord/cog.py index 300c6ecadd..aa4ff1bbf8 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -50,8 +50,9 @@ from . import errors from .commands import ( + Invokable, + BaseContext, ApplicationCommand, - ApplicationContext, SlashCommandGroup, _BaseCommand, ) @@ -140,7 +141,7 @@ async def bar(self, ctx): __cog_name__: str __cog_settings__: Dict[str, Any] - __cog_commands__: List[ApplicationCommand] + __cog_commands__: List[Invokable] __cog_listeners__: List[Tuple[str, str]] __cog_guild_ids__: List[int] @@ -275,7 +276,7 @@ class Cog(metaclass=CogMeta): __cog_name__: ClassVar[str] __cog_settings__: ClassVar[Dict[str, Any]] - __cog_commands__: ClassVar[List[ApplicationCommand]] + __cog_commands__: ClassVar[List[Invokable]] __cog_listeners__: ClassVar[List[Tuple[str, str]]] __cog_guild_ids__: ClassVar[List[int]] @@ -285,7 +286,7 @@ def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT: # To do this, we need to interfere with the Cog creation process. return super().__new__(cls) - def get_commands(self) -> List[ApplicationCommand]: + def get_commands(self) -> List[Invokable]: r""" Returns -------- @@ -313,12 +314,12 @@ def description(self) -> str: def description(self, description: str) -> None: self.__cog_description__ = description - def walk_commands(self) -> Generator[ApplicationCommand, None, None]: + def walk_commands(self) -> Generator[Invokable, None, None]: """An iterator that recursively walks through this cog's commands and subcommands. Yields ------ - Union[:class:`.Command`, :class:`.Group`] + Union[:class:`.Invokable`] A command or group from the cog. """ for command in self.__cog_commands__: @@ -401,7 +402,7 @@ def cog_unload(self) -> None: pass @_cog_special_method - def bot_check_once(self, ctx: ApplicationContext) -> bool: + def bot_check_once(self, ctx: BaseContext) -> bool: """A special method that registers as a :meth:`.Bot.check_once` check. @@ -410,43 +411,45 @@ def bot_check_once(self, ctx: ApplicationContext) -> bool: Parameters ----------- - ctx: :class:`.Context` + ctx: :class:`.BaseContext` The invocation context. """ return True @_cog_special_method - def bot_check(self, ctx: ApplicationContext) -> bool: + def bot_check(self, ctx: BaseContext) -> bool: """A special method that registers as a :meth:`.Bot.check` check. This function **can** be a coroutine and must take a sole parameter, - ``ctx``, to represent the :class:`.Context` or :class:`.ApplicationContext`. + ``ctx``, to represent a subclass of :class:`BaseContext` (either :class:`.Context` + or :class:`.ApplicationContext`). Parameters ----------- - ctx: :class:`.Context` + ctx: :class:`.BaseContext` The invocation context. """ return True @_cog_special_method - def cog_check(self, ctx: ApplicationContext) -> bool: + def cog_check(self, ctx: BaseContext) -> bool: """A special method that registers as a :func:`~discord.ext.commands.check` for every command and subcommand in this cog. This function **can** be a coroutine and must take a sole parameter, - ``ctx``, to represent the :class:`.Context` or :class:`.ApplicationContext`. + ``ctx``, to represent a subclass of :class:`BaseContext` (either :class:`.Context` + or :class:`.ApplicationContext`). Parameters ----------- - ctx: :class:`.Context` + ctx: :class:`.BaseContext` The invocation context. """ return True @_cog_special_method - async def cog_command_error(self, ctx: ApplicationContext, error: Exception) -> None: + async def cog_command_error(self, ctx: BaseContext, error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. @@ -457,7 +460,7 @@ async def cog_command_error(self, ctx: ApplicationContext, error: Exception) -> Parameters ----------- - ctx: :class:`.ApplicationContext` + ctx: :class:`.BaseContext` The invocation context where the error happened. error: :class:`ApplicationCommandError` The error that happened. @@ -465,31 +468,31 @@ async def cog_command_error(self, ctx: ApplicationContext, error: Exception) -> pass @_cog_special_method - async def cog_before_invoke(self, ctx: ApplicationContext) -> None: + async def cog_before_invoke(self, ctx: BaseContext) -> None: """A special method that acts as a cog local pre-invoke hook. - This is similar to :meth:`.ApplicationCommand.before_invoke`. + This is similar to :meth:`.Invokable.before_invoke`. This **must** be a coroutine. Parameters ----------- - ctx: :class:`.ApplicationContext` + ctx: :class:`.BaseContext` The invocation context. """ pass @_cog_special_method - async def cog_after_invoke(self, ctx: ApplicationContext) -> None: + async def cog_after_invoke(self, ctx: BaseContext) -> None: """A special method that acts as a cog local post-invoke hook. - This is similar to :meth:`.ApplicationCommand.after_invoke`. + This is similar to :meth:`.BaseContext.after_invoke`. This **must** be a coroutine. Parameters ----------- - ctx: :class:`.ApplicationContext` + ctx: :class:`.BaseContext` The invocation context. """ pass diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 974e77bd4c..194e81293b 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -29,7 +29,7 @@ import discord from ...cog import Cog -from ...commands import ApplicationCommand, SlashCommandGroup +from ...commands import Invokable, ApplicationCommand, SlashCommandGroup if TYPE_CHECKING: from .core import Command @@ -49,12 +49,12 @@ def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT: # To do this, we need to interfere with the Cog creation process. return super().__new__(cls) - def walk_commands(self) -> Generator[Command, None, None]: + def walk_commands(self) -> Generator[Invokable, None, None]: """An iterator that recursively walks through this cog's commands and subcommands. Yields ------ - Union[:class:`.Command`, :class:`.Group`] + Union[:class:`.Invokable`] A command or group from the cog. """ from .core import GroupMixin @@ -70,11 +70,11 @@ def walk_commands(self) -> Generator[Command, None, None]: else: yield command - def get_commands(self) -> List[Union[ApplicationCommand, Command]]: + def get_commands(self) -> List[Invokable]: r""" Returns -------- - List[Union[:class:`~discord.ApplicationCommand`, :class:`.Command`]] + List[:class:`~discord.Invokable`]] A :class:`list` of commands that are defined inside this cog. .. note:: From 4fe2c6856cfbe1b8cf8ec1759890421f9c22399d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Oct 2022 16:12:55 +0000 Subject: [PATCH 29/54] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- discord/cog.py | 8 +- discord/commands/__init__.py | 2 +- discord/commands/context.py | 6 +- discord/commands/cooldowns.py | 48 +++++--- discord/commands/core.py | 29 +++-- discord/commands/mixins.py | 209 +++++++++++++++++--------------- discord/errors.py | 35 +++--- discord/ext/commands/cog.py | 2 +- discord/ext/commands/context.py | 8 +- discord/ext/commands/core.py | 26 +--- discord/ext/commands/errors.py | 8 +- 11 files changed, 202 insertions(+), 179 deletions(-) diff --git a/discord/cog.py b/discord/cog.py index 39c81d0064..b2f9865c2a 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -36,9 +36,9 @@ from . import errors from .commands import ( - Invokable, - BaseContext, ApplicationCommand, + BaseContext, + Invokable, SlashCommandGroup, _BaseCommand, ) @@ -475,9 +475,7 @@ def cog_check(self, ctx: BaseContext) -> bool: return True @_cog_special_method - async def cog_command_error( - self, ctx: BaseContext, error: Exception - ) -> None: + async def cog_command_error(self, ctx: BaseContext, error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py index 21e54de029..259500cf34 100644 --- a/discord/commands/__init__.py +++ b/discord/commands/__init__.py @@ -25,6 +25,6 @@ from .context import * from .core import * +from .mixins import * from .options import * from .permissions import * -from .mixins import * diff --git a/discord/commands/context.py b/discord/commands/context.py index 4a62dce361..05bc5d2546 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -29,6 +29,7 @@ import discord.abc from discord.interactions import Interaction, InteractionMessage, InteractionResponse from discord.webhook.async_ import Webhook + from .mixins import BaseContext if TYPE_CHECKING: @@ -78,6 +79,7 @@ class ApplicationContext(BaseContext): command: :class:`.ApplicationCommand` The command that this context belongs to. """ + command: Optional[ApplicationCommand] def __init__( @@ -88,7 +90,7 @@ def __init__( command: Optional[ApplicationCommand] = None, args: List[Any] = None, kwargs: Dict[str, Any] = None, - **kwargs2 + **kwargs2, ): super().__init__(bot=bot, command=command, args=args, kwargs=kwargs, **kwargs2) @@ -125,7 +127,7 @@ async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> N The default is to start where we left off. Raises - ------- + ------ ValueError The context to reinvoke is not valid. """ diff --git a/discord/commands/cooldowns.py b/discord/commands/cooldowns.py index 3678d45e7c..462874a846 100644 --- a/discord/commands/cooldowns.py +++ b/discord/commands/cooldowns.py @@ -85,7 +85,7 @@ class Cooldown: """Represents a cooldown for a command. Attributes - ----------- + ---------- rate: :class:`int` The total number of tokens available per :attr:`per` seconds. per: :class:`float` @@ -101,17 +101,17 @@ def __init__(self, rate: float, per: float) -> None: self._tokens: int = self.rate self._last: float = 0.0 - def get_tokens(self, current: Optional[float] = None) -> int: + def get_tokens(self, current: float | None = None) -> int: """Returns the number of available tokens before rate limiting is applied. Parameters - ------------ + ---------- current: Optional[:class:`float`] The time in seconds since Unix epoch to calculate tokens at. If not supplied then :func:`time.time()` is used. Returns - -------- + ------- :class:`int` The number of tokens available before the cooldown is to be applied. """ @@ -124,11 +124,11 @@ def get_tokens(self, current: Optional[float] = None) -> int: tokens = self.rate return tokens - def get_retry_after(self, current: Optional[float] = None) -> float: + def get_retry_after(self, current: float | None = None) -> float: """Returns the time in seconds until the cooldown will be reset. Parameters - ------------- + ---------- current: Optional[:class:`float`] The current time in seconds since Unix epoch. If not supplied, then :func:`time.time()` is used. @@ -146,11 +146,11 @@ def get_retry_after(self, current: Optional[float] = None) -> float: return 0.0 - def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: + def update_rate_limit(self, current: float | None = None) -> float | None: """Updates the cooldown rate limit. Parameters - ------------- + ---------- current: Optional[:class:`float`] The time in seconds since Unix epoch to update the rate limit at. If not supplied, then :func:`time.time()` is used. @@ -185,7 +185,7 @@ def copy(self) -> Cooldown: """Creates a copy of this cooldown. Returns - -------- + ------- :class:`Cooldown` A new instance of this cooldown. """ @@ -198,14 +198,14 @@ def __repr__(self) -> str: class CooldownMapping: def __init__( self, - original: Optional[Cooldown], + original: Cooldown | None, type: Callable[[Message], Any], ) -> None: if not callable(type): raise TypeError("Cooldown type must be a BucketType or callable") - self._cache: Dict[Any, Cooldown] = {} - self._cooldown: Optional[Cooldown] = original + self._cache: dict[Any, Cooldown] = {} + self._cooldown: Cooldown | None = original self._type: Callable[[Message], Any] = type def copy(self) -> CooldownMapping: @@ -222,13 +222,13 @@ def type(self) -> Callable[[Message], Any]: return self._type @classmethod - def from_cooldown(cls: Type[C], rate, per, type) -> C: + def from_cooldown(cls: type[C], rate, per, type) -> C: return cls(Cooldown(rate, per), type) def _bucket_key(self, msg: Message) -> Any: return self._type(msg) - def _verify_cache_integrity(self, current: Optional[float] = None) -> None: + def _verify_cache_integrity(self, current: float | None = None) -> None: # we want to delete all cache objects that haven't been used # in a cooldown window. e.g. if we have a command that has a # cooldown of 60s, and it has not been used in 60s then that key should be deleted @@ -240,7 +240,7 @@ def _verify_cache_integrity(self, current: Optional[float] = None) -> None: def create_bucket(self, message: Message) -> Cooldown: return self._cooldown.copy() # type: ignore - def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown: + def get_bucket(self, message: Message, current: float | None = None) -> Cooldown: if self._type is BucketType.default: return self._cooldown # type: ignore @@ -255,13 +255,17 @@ def get_bucket(self, message: Message, current: Optional[float] = None) -> Coold return bucket - def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: + def update_rate_limit( + self, message: Message, current: float | None = None + ) -> float | None: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) class DynamicCooldownMapping(CooldownMapping): - def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None: + def __init__( + self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] + ) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory @@ -342,7 +346,7 @@ class MaxConcurrency: __slots__ = ("number", "per", "wait", "_mapping") def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: - self._mapping: Dict[Any, _Semaphore] = {} + self._mapping: dict[Any, _Semaphore] = {} self.per: BucketType = per self.number: int = number self.wait: bool = wait @@ -351,13 +355,17 @@ def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: raise ValueError("max_concurrency 'number' cannot be less than 1") if not isinstance(per, BucketType): - raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") + raise TypeError( + f"max_concurrency 'per' must be of type BucketType not {type(per)!r}" + ) def copy(self: MC) -> MC: return self.__class__(self.number, per=self.per, wait=self.wait) def __repr__(self) -> str: - return f"" + return ( + f"" + ) def get_key(self, message: Message) -> Any: return self.per.get_key(message) diff --git a/discord/commands/core.py b/discord/commands/core.py index de668a27cb..4b875562bd 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -52,9 +52,9 @@ from ..role import Role from ..threads import Thread from ..user import User -from ..utils import find, MISSING -from .mixins import Invokable, _BaseCommand, CogT +from ..utils import MISSING, find from .context import ApplicationContext, AutocompleteContext +from .mixins import CogT, Invokable, _BaseCommand from .options import Option, OptionChoice __all__ = ( @@ -107,10 +107,11 @@ class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): - :class:`ContextMenuCommand` which in turn is a superclass of - :class:`MessageCommand` and - :class:`UserCommand` - + This is a subclass of :class:`.Invokable`. """ - __original_kwargs__: Dict[str, Any] + + __original_kwargs__: dict[str, Any] cog = None def __init__(self, func: Callable, **kwargs) -> None: @@ -161,7 +162,7 @@ class SlashCommand(ApplicationCommand): .. versionadded:: 2.0 Attributes - ----------- + ---------- guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. options: List[:class:`Option`] @@ -181,6 +182,7 @@ class SlashCommand(ApplicationCommand): The description localizations for this command. The values of this should be ``"locale": "description"``. See `here `_ for a list of valid locales. """ + type = 1 def __new__(cls, *args, **kwargs) -> SlashCommand: @@ -534,7 +536,7 @@ class SlashCommandGroup(ApplicationCommand): .. versionadded:: 2.0 Attributes - ----------- + ---------- description: Optional[:class:`str`] The description for the command. guild_ids: Optional[List[:class:`int`]] @@ -553,6 +555,7 @@ class SlashCommandGroup(ApplicationCommand): The description localizations for this command. The values of this should be ``"locale": "description"``. See `here `_ for a list of valid locales. """ + __initial_commands__: list[SlashCommand | SlashCommandGroup] type = 1 @@ -708,7 +711,9 @@ def create_subgroup( """ if self.parent is not None: - raise Exception("A command subgroup can only have commands and not any more groups.") + raise Exception( + "A command subgroup can only have commands and not any more groups." + ) sub_command_group = SlashCommandGroup( name, description, guild_ids, parent=self, **kwargs @@ -804,7 +809,7 @@ class ContextMenuCommand(ApplicationCommand): .. versionadded:: 2.0 Attributes - ----------- + ---------- guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. guild_only: :class:`bool` @@ -906,10 +911,11 @@ class UserCommand(ContextMenuCommand): .. versionadded:: 2.0 Attributes - ----------- + ---------- guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. """ + type = 2 def __new__(cls, *args, **kwargs) -> UserCommand: @@ -958,10 +964,11 @@ class MessageCommand(ContextMenuCommand): .. versionadded:: 2.0 Attributes - ----------- + ---------- guild_ids: Optional[List[:class:`int`]] The ids of the guilds where this command will be registered. """ + type = 3 def __new__(cls, *args, **kwargs) -> MessageCommand: @@ -1041,7 +1048,7 @@ def application_command(cls=SlashCommand, **attrs): """A decorator that transforms a function into an :class:`.ApplicationCommand`. More specifically, one of :class:`.SlashCommand`, :class:`.UserCommand`, or :class:`.MessageCommand`. The exact class depends on the ``cls`` parameter. - + The ``description`` and ``name`` of the command are automatically inferred from the function name and function docstring. diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 0dba9b0ac8..69e606433e 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -5,41 +5,41 @@ import functools from typing import ( TYPE_CHECKING, - TypeVar, + Any, Callable, Coroutine, - Optional, - Union, - Generic, - Any, Dict, + Generic, List, + Optional, + TypeVar, + Union, ) -from .. import utils, abc +from .. import abc, utils from ..errors import ( ApplicationCommandError, CheckFailure, CommandError, CommandInvokeError, - DisabledCommand, CommandOnCooldown, + DisabledCommand, ) from .cooldowns import BucketType, CooldownMapping, MaxConcurrency if TYPE_CHECKING: from typing_extensions import ParamSpec - from ..bot import Bot, AutoShardedBot - from ..user import User, ClientUser - from ..member import Member + from ..abc import MessageableChannel + from ..bot import AutoShardedBot, Bot + from ..cog import Cog from ..guild import Guild - from ..message import Message from ..interactions import Interaction - from ..abc import MessageableChannel - from ..voice_client import VoiceProtocol + from ..member import Member + from ..message import Message from ..state import ConnectionState - from ..cog import Cog + from ..user import ClientUser, User + from ..voice_client import VoiceProtocol P = ParamSpec("P") else: @@ -65,10 +65,7 @@ ] ErrorT = TypeVar("ErrorT", bound="Error") -Hook = Union[ - Callable[[CogT, ContextT], Coro[Any]], - Callable[[ContextT], Coro[Any]] -] +Hook = Union[Callable[[CogT, ContextT], Coro[Any]], Callable[[ContextT], Coro[Any]]] HookT = TypeVar("HookT", bound="Hook") @@ -135,7 +132,7 @@ class _BaseCommand: class BaseContext(abc.Messageable, Generic[BotT]): - """A baseclass to provide ***basic & common functionality*** between + r"""A baseclass to provide ***basic & common functionality*** between :class:`.ApplicationContext` and :class:`~ext.commands.Context`. This is a subclass of :class:`~abc.Messageable` and can be used to @@ -182,41 +179,43 @@ class BaseContext(abc.Messageable, Generic[BotT]): A boolean that indicates if the command failed to be parsed, checked, or invoked. """ + def __init__( self, bot: Bot, - command: Optional[Invokable], - args: List[Any] = utils.MISSING, - kwargs: Dict[str, Any] = utils.MISSING, + command: Invokable | None, + args: list[Any] = utils.MISSING, + kwargs: dict[str, Any] = utils.MISSING, *, - invoked_with: Optional[str] = None, - invoked_parents: List[str] = utils.MISSING, - invoked_subcommand: Optional[Invokable] = None, - subcommand_passed: Optional[str] = None, - command_failed: bool = False - + invoked_with: str | None = None, + invoked_parents: list[str] = utils.MISSING, + invoked_subcommand: Invokable | None = None, + subcommand_passed: str | None = None, + command_failed: bool = False, ): self.bot: Bot = bot - self.command: Optional[Invokable] = command - self.args: List[Any] = args or [] - self.kwargs: Dict[str, Any] = kwargs or {} + self.command: Invokable | None = command + self.args: list[Any] = args or [] + self.kwargs: dict[str, Any] = kwargs or {} - self.invoked_with: Optional[str] = invoked_with + self.invoked_with: str | None = invoked_with if not self.invoked_with and command: self.invoked_with = command.name - self.invoked_parents: List[str] = invoked_parents or [] + self.invoked_parents: list[str] = invoked_parents or [] if not self.invoked_parents and command: self.invoked_parents = [i.name for i in command.parents] # This will always be None for slash commands - self.subcommand_passed: Optional[str] = subcommand_passed + self.subcommand_passed: str | None = subcommand_passed - self.invoked_subcommand: Optional[Invokable] = invoked_subcommand + self.invoked_subcommand: Invokable | None = invoked_subcommand self.command_failed: bool = command_failed - async def invoke(self, command: Invokable[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: - """|coro| + async def invoke( + self, command: Invokable[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs + ) -> T: + r"""|coro| Calls a command with the arguments given. @@ -252,7 +251,7 @@ async def _get_channel(self) -> abc.Messageable: return self.channel @property - def source(self) -> Union[Message, Interaction]: + def source(self) -> Message | Interaction: """Union[:class:`.Message`, :class:`.Interaction`]: Property to return a message or interaction depending on the context. """ @@ -263,22 +262,24 @@ def _state(self) -> ConnectionState: return self.source._state @property - def cog(self) -> Optional[Cog]: + def cog(self) -> Cog | None: """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. - None if it does not exist.""" + None if it does not exist. + """ if self.command is None: return None return self.command.cog @utils.cached_property - def guild(self) -> Optional[Guild]: + def guild(self) -> Guild | None: """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. - None if not available.""" + None if not available. + """ return self.source.guild @utils.cached_property - def guild_id(self) -> Optional[int]: + def guild_id(self) -> int | None: """:class:`int`: Returns the ID of the guild associated with this context's command.""" return getattr(self.source, "guild_id", self.guild.id if self.guild else None) @@ -288,22 +289,24 @@ def channel(self) -> MessageableChannel: return self.source.channel @utils.cached_property - def channel_id(self) -> Optional[int]: + def channel_id(self) -> int | None: """:class:`int`: Returns the ID of the channel associated with this context's command.""" - return getattr(self.source, "channel_id", self.channel.id if self.channel else None) + return getattr( + self.source, "channel_id", self.channel.id if self.channel else None + ) @utils.cached_property - def author(self) -> Union[User, Member]: + def author(self) -> User | Member: """Union[:class:`.User`, :class:`.Member`]: Returns the author associated with this context's command.""" return self.source.author @property - def user(self) -> Union[User, Member]: + def user(self) -> User | Member: """Union[:class:`.User`, :class:`.Member`]: Alias for :attr:`BaseContext.author`.""" return self.author @utils.cached_property - def me(self) -> Union[Member, ClientUser]: + def me(self) -> Member | ClientUser: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message message contexts, or when :meth:`.Intents.guilds` is absent. @@ -312,13 +315,13 @@ def me(self) -> Union[Member, ClientUser]: return self.guild.me if self.guild and self.guild.me else self.bot.user # type: ignore @property - def voice_client(self) -> Optional[VoiceProtocol]: - """Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" + def voice_client(self) -> VoiceProtocol | None: + r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" return self.guild.voice_client if self.guild else None class Invokable(Generic[CogT, P, T]): - """A baseclass to provide ***basic & common functionality*** between + r"""A baseclass to provide ***basic & common functionality*** between :class:`.ApplicationCommand` and :class:`~ext.commands.Command`. .. versionadded:: 2.2 @@ -352,10 +355,15 @@ class Invokable(Generic[CogT, P, T]): cooldown: Optional[:class:`Cooldown`] The cooldown applied when the command is invoked. """ + def __init__(self, func: CallbackT, **kwargs): self.callback: CallbackT = func - self.parent: Optional[Invokable] = parent if isinstance((parent := kwargs.get("parent")), _BaseCommand) else None - self.cog: Optional[CogT] = None + self.parent: Invokable | None = ( + parent + if isinstance((parent := kwargs.get("parent")), _BaseCommand) + else None + ) + self.cog: CogT | None = None self.module: Any = None self.name: str = str(kwargs.get("name", func.__name__)) @@ -366,8 +374,10 @@ def __init__(self, func: CallbackT, **kwargs): if checks := getattr(func, "__commands_checks__", []): checks.reverse() - checks += kwargs.get("checks", []) # combine all the checks we find (kwargs or decorator) - self.checks: List[Check] = checks + checks += kwargs.get( + "checks", [] + ) # combine all the checks we find (kwargs or decorator) + self.checks: list[Check] = checks # cooldowns cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) @@ -377,23 +387,27 @@ def __init__(self, func: CallbackT, **kwargs): elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) self._buckets: CooldownMapping = buckets # max concurrency - self._max_concurrency: Optional[MaxConcurrency] = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) + self._max_concurrency: MaxConcurrency | None = getattr( + func, "__commands_max_concurrency__", kwargs.get("max_concurrency") + ) # hooks - self._before_invoke: Optional[Hook] = None + self._before_invoke: Hook | None = None if hook := getattr(func, "__before_invoke__", None): self.before_invoke(hook) - self._after_invoke: Optional[Hook] = None + self._after_invoke: Hook | None = None if hook := getattr(func, "__after_invoke__", None): self.after_invoke(hook) - self.on_error: Optional[Error] + self.on_error: Error | None @property def callback(self) -> CallbackT: @@ -426,12 +440,12 @@ def qualified_name(self) -> str: return f"{self.parent.qualified_name} {self.name}" @property - def cog_name(self) -> Optional[str]: + def cog_name(self) -> str | None: """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" return type(self.cog).__cog_name__ if self.cog is not None else None @property - def parents(self) -> List[Invokable]: + def parents(self) -> list[Invokable]: """List[:class:`Invokable`]: Retrieves the parents of this command. If the command has no parents then it returns an empty :class:`list`. @@ -447,7 +461,7 @@ def parents(self) -> List[Invokable]: return entries @property - def root_parent(self) -> Optional[Invokable]: + def root_parent(self) -> Invokable | None: """Optional[:class:`Invokable`]: Retrieves the root parent of this command. If the command has no parents then it returns ``None``. @@ -459,7 +473,7 @@ def root_parent(self) -> Optional[Invokable]: return self.parents[-1] @property - def full_parent_name(self) -> Optional[str]: + def full_parent_name(self) -> str | None: """:class:`str`: Retrieves the fully qualified parent command name. This the base command name required to execute it. For example, @@ -481,7 +495,6 @@ async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): This bypasses all mechanisms -- including checks, converters, invoke hooks, cooldowns, etc. You must take care to pass the proper arguments and types to this function. - """ if self.cog is not None: return await self.callback(self.cog, ctx, *args, **kwargs) @@ -489,7 +502,7 @@ async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): def update(self, **kwargs: Any) -> None: """Updates the :class:`Command` instance with updated attribute. - + Similar to creating a new instance except it updates the current. """ self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) @@ -502,12 +515,12 @@ def error(self, coro: ErrorT) -> ErrorT: invoked afterwards as the catch-all. Parameters - ----------- + ---------- coro: :ref:`coroutine ` The coroutine to register as the local error handler. Raises - ------- + ------ TypeError The coroutine passed is not actually a coroutine. """ @@ -533,12 +546,12 @@ def before_invoke(self, coro: HookT) -> HookT: See :meth:`.Bot.before_invoke` for more info. Parameters - ----------- + ---------- coro: :ref:`coroutine ` The coroutine to register as the pre-invoke hook. Raises - ------- + ------ TypeError The coroutine passed is not actually a coroutine. """ @@ -560,12 +573,12 @@ def after_invoke(self, coro: HookT) -> HookT: See :meth:`.Bot.after_invoke` for more info. Parameters - ----------- + ---------- coro: :ref:`coroutine ` The coroutine to register as the post-invoke hook. Raises - ------- + ------ TypeError The coroutine passed is not actually a coroutine. """ @@ -586,20 +599,20 @@ async def can_run(self, ctx: ContextT) -> bool: Checks whether the command is disabled or not Parameters - ----------- + ---------- ctx: :class:`.Context` The ctx of the command currently being invoked. - Raises + Returns ------- + :class:`bool` + A boolean indicating if the command can be invoked. + + Raises + ------ :class:`CommandError` Any command error that was raised during a check call will be propagated by this function. - - Returns - -------- - :class:`bool` - A boolean indicating if the command can be invoked. """ if not self.enabled: raise DisabledCommand(f"{self.name} command is disabled") @@ -609,9 +622,13 @@ async def can_run(self, ctx: ContextT) -> bool: try: if not await ctx.bot.can_run(ctx): - raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) - if (cog := self.cog) and (local_check := cog._get_overridden_method(cog.cog_check)): + if (cog := self.cog) and ( + local_check := cog._get_overridden_method(cog.cog_check) + ): ret = await utils.maybe_coroutine(local_check, ctx) if not ret: return False @@ -631,7 +648,7 @@ def add_check(self, func: Check) -> None: This is the non-decorator interface to :func:`.check`. Parameters - ----------- + ---------- func: Callable The function that will be used as a check. """ @@ -645,7 +662,7 @@ def remove_check(self, func: Check) -> None: if the function is not in the command's checks. Parameters - ----------- + ---------- func: Callable The function to remove from the checks. """ @@ -659,7 +676,7 @@ def copy(self): """Creates a copy of this command. Returns - -------- + ------- :class:`Invokable` A new instance of this command. """ @@ -672,10 +689,10 @@ def _ensure_assignment_on_copy(self, other: Invokable): if self.checks != other.checks: other.checks = self.checks.copy() if self._buckets.valid and not other._buckets.valid: - other._buckets = self._buckets.copy() + other._buckets = self._buckets.copy() if self._max_concurrency != other._max_concurrency: - # _max_concurrency won't be None at this point - other._max_concurrency = self._max_concurrency.copy() # type: ignore + # _max_concurrency won't be None at this point + other._max_concurrency = self._max_concurrency.copy() # type: ignore try: other.on_error = self.on_error @@ -683,7 +700,7 @@ def _ensure_assignment_on_copy(self, other: Invokable): pass return other - def _update_copy(self, kwargs: Dict[str, Any]): + def _update_copy(self, kwargs: dict[str, Any]): if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) @@ -703,12 +720,12 @@ def is_on_cooldown(self, ctx: ContextT) -> bool: This uses the current time instead of the interaction time. Parameters - ----------- + ---------- ctx: :class:`.ApplicationContext` The invocation context to use when checking the command's cooldown status. Returns - -------- + ------- :class:`bool` A boolean indicating if the command is on cooldown. """ @@ -723,7 +740,7 @@ def reset_cooldown(self, ctx) -> None: """Resets the cooldown on this command. Parameters - ----------- + ---------- ctx: :class:`.ApplicationContext` The invocation context to reset the cooldown under. """ @@ -739,12 +756,12 @@ def get_cooldown_retry_after(self, ctx) -> float: This uses the current time instead of the interaction time. Parameters - ----------- + ---------- ctx: :class:`.ApplicationContext` The invocation context to retrieve the cooldown from. Returns - -------- + ------- :class:`float` The amount of time left on this command's cooldown in seconds. If this is ``0.0`` then the command isn't on cooldown. @@ -821,7 +838,9 @@ async def prepare(self, ctx: ContextT) -> None: ctx.command = self if not await self.can_run(ctx): - raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") + raise CheckFailure( + f"The check functions for command {self.qualified_name} failed." + ) if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -843,7 +862,7 @@ async def prepare(self, ctx: ContextT) -> None: async def invoke(self, ctx: ContextT) -> None: """Runs the command with checks. - + Parameters ---------- ctx: :class:`.BaseContext` diff --git a/discord/errors.py b/discord/errors.py index 4986698d7d..9e7466f0ca 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -37,8 +37,8 @@ except ModuleNotFoundError: _ResponseType = ClientResponse - from .interactions import Interaction from .commands.cooldowns import BucketType, Cooldown + from .interactions import Interaction __all__ = ( "DiscordException", @@ -273,6 +273,7 @@ def __init__(self, interaction: Interaction): # command errors + class CommandError(DiscordException): r"""The base exception type for all command related errors. @@ -286,7 +287,9 @@ class CommandError(DiscordException): def __init__(self, message: Optional[str] = None, *args: Any) -> None: if message is not None: # clean-up @everyone and @here mentions - m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") + m = message.replace("@everyone", "@\u200beveryone").replace( + "@here", "@\u200bhere" + ) super().__init__(m, *args) else: super().__init__(*args) @@ -301,7 +304,6 @@ class ApplicationCommandError(CommandError): in a special way as they are caught and passed into a special event from :class:`.Bot`\, :func:`.on_command_error`. """ - pass class CommandInvokeError(CommandError): @@ -310,7 +312,7 @@ class CommandInvokeError(CommandError): This inherits from :exc:`CommandError` Attributes - ----------- + ---------- original: :exc:`Exception` The original exception that was raised. You can also get this via the ``__cause__`` attribute. @@ -327,7 +329,7 @@ class ApplicationCommandInvokeError(ApplicationCommandError): This inherits from :exc:`ApplicationCommandError` Attributes - ----------- + ---------- original: :exc:`Exception` The original exception that was raised. You can also get this via the ``__cause__`` attribute. @@ -335,7 +337,9 @@ class ApplicationCommandInvokeError(ApplicationCommandError): def __init__(self, e: Exception) -> None: self.original: Exception = e - super().__init__(f"Application Command raised an exception: {e.__class__.__name__}: {e}") + super().__init__( + f"Application Command raised an exception: {e.__class__.__name__}: {e}" + ) class CheckFailure(CommandError): @@ -344,8 +348,6 @@ class CheckFailure(CommandError): This inherits from :exc:`CommandError` """ - pass - class MaxConcurrencyReached(CommandError): """Exception raised when the command being invoked has reached its maximum concurrency. @@ -353,7 +355,7 @@ class MaxConcurrencyReached(CommandError): This inherits from :exc:`CommandError`. Attributes - ------------ + ---------- number: :class:`int` The maximum number of concurrent invokers allowed. per: :class:`.BucketType` @@ -367,7 +369,9 @@ def __init__(self, number: int, per: BucketType) -> None: suffix = f"per {name}" if per.name != "default" else "globally" plural = "%s times %s" if number > 1 else "%s time %s" fmt = plural % (number, suffix) - super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") + super().__init__( + f"Too many people are using this command. It can only be used {fmt} concurrently." + ) class CommandOnCooldown(CommandError): @@ -376,7 +380,7 @@ class CommandOnCooldown(CommandError): This inherits from :exc:`CommandError` Attributes - ----------- + ---------- cooldown: :class:`.Cooldown` A class with attributes ``rate`` and ``per`` similar to the :func:`.cooldown` decorator. @@ -386,7 +390,9 @@ class CommandOnCooldown(CommandError): The amount of seconds to wait before you can retry again. """ - def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: + def __init__( + self, cooldown: Cooldown, retry_after: float, type: BucketType + ) -> None: self.cooldown: Cooldown = cooldown self.retry_after: float = retry_after self.type: BucketType = type @@ -399,8 +405,6 @@ class DisabledCommand(CommandError): This inherits from :exc:`CommandError` """ - pass - class UserInputError(CommandError): """The base exception type for errors that involve errors @@ -409,9 +413,6 @@ class UserInputError(CommandError): This inherits from :exc:`CommandError`. """ - pass - - class ExtensionError(DiscordException): """Base exception for extension related errors. diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index b52661530c..5f9eacb72a 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -29,7 +29,7 @@ import discord from ...cog import Cog -from ...commands import Invokable, ApplicationCommand, SlashCommandGroup +from ...commands import ApplicationCommand, Invokable, SlashCommandGroup if TYPE_CHECKING: from .core import Command diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 9c7cf36ec7..7fb908c12e 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -31,7 +31,8 @@ import discord.abc import discord.utils from discord.message import Message -from ...commands import BaseContext, ApplicationContext + +from ...commands import ApplicationContext, BaseContext if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -39,8 +40,8 @@ from .bot import AutoShardedBot, Bot from .cog import Cog from .core import Command - from .view import StringView from .help import HelpCommand + from .view import StringView __all__ = ("Context",) @@ -67,7 +68,7 @@ class Context(BaseContext, Generic[BotT]): This class implements the :class:`~discord.abc.Messageable` ABC. Attributes - ----------- + ---------- message: :class:`.Message` The message that triggered the command being executed. current_parameter: Optional[:class:`inspect.Parameter`] @@ -81,6 +82,7 @@ class Context(BaseContext, Generic[BotT]): A boolean that indicates if the command failed to be parsed, checked, or invoked. """ + command: Optional[Command] def __init__( diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 36af5012f3..283f650319 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -48,12 +48,7 @@ slash_command, user_command, ) -from ...commands.mixins import ( - CogT, - Invokable, - hooked_wrapped_callback, - unwrap_function, -) +from ...commands.mixins import CogT, Invokable, hooked_wrapped_callback, unwrap_function from ...enums import ChannelType from ...errors import * from .context import Context @@ -86,14 +81,8 @@ ErrorT = TypeVar("ErrorT", bound="Error") CallbackT = Union[ - Callable[ - [Concatenate[CogT, ContextT, P]], - Coro[T] - ], - Callable[ - [Concatenate[ContextT, P]], - Coro[T] - ], + Callable[[Concatenate[CogT, ContextT, P]], Coro[T]], + Callable[[Concatenate[ContextT, P]], Coro[T]], ] else: P = TypeVar("P") @@ -178,7 +167,7 @@ def __setitem__(self, k, v): class Command(Invokable, _BaseCommand, Generic[CogT, P, T]): - """A class that implements the protocol for a bot text command. + r"""A class that implements the protocol for a bot text command. These are not created manually, instead they are created via the decorator or functional interface. @@ -288,10 +277,7 @@ def callback( return self._callback @callback.setter - def callback( - self, - func: CallbackT - ) -> None: + def callback(self, func: CallbackT) -> None: self._callback = func unwrap = unwrap_function(func) self.module = unwrap.__module__ @@ -1016,7 +1002,7 @@ def command( ], Command[CogT, P, T] | CommandT, ]: - """A decorator that transforms a function into a :class:`.Command` + r"""A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. By default the ``help`` attribute is received automatically from the diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 29ec634e4d..16e988dbc7 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -28,14 +28,14 @@ from typing import TYPE_CHECKING, Any, Callable from discord.errors import ( + CheckFailure, ClientException, CommandError, - CheckFailure, - CommandOnCooldown, CommandInvokeError, - MaxConcurrencyReached, + CommandOnCooldown, DisabledCommand, - UserInputError + MaxConcurrencyReached, + UserInputError, ) if TYPE_CHECKING: From ab8946215f67143be5b7c7230da3cb12b56c2353 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 16 Oct 2022 17:23:39 +0000 Subject: [PATCH 30/54] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- discord/commands/cooldowns.py | 2 +- discord/commands/core.py | 11 +---------- discord/commands/mixins.py | 13 +------------ 3 files changed, 3 insertions(+), 23 deletions(-) diff --git a/discord/commands/cooldowns.py b/discord/commands/cooldowns.py index 462874a846..761b4274fb 100644 --- a/discord/commands/cooldowns.py +++ b/discord/commands/cooldowns.py @@ -28,7 +28,7 @@ import asyncio import time from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar from discord.enums import Enum diff --git a/discord/commands/core.py b/discord/commands/core.py index 4b875562bd..c04980ba71 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -31,16 +31,7 @@ import types from collections import OrderedDict from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Generic, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Generator, Generic, TypeVar, Union from ..channel import _threaded_guild_channel_factory from ..enums import Enum as DiscordEnum diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 69e606433e..bca9db4665 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -3,18 +3,7 @@ import asyncio import datetime import functools -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Coroutine, - Dict, - Generic, - List, - Optional, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar, Union from .. import abc, utils from ..errors import ( From d8b03ee54a52a29c598635a9589b45101d9544af Mon Sep 17 00:00:00 2001 From: BobDotCom <71356958+BobDotCom@users.noreply.github.com> Date: Mon, 17 Oct 2022 14:30:47 -0500 Subject: [PATCH 31/54] Fix typehints --- discord/commands/context.py | 8 ++++---- discord/commands/core.py | 4 ++-- discord/errors.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/discord/commands/context.py b/discord/commands/context.py index 05bc5d2546..7e9791b46d 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -80,16 +80,16 @@ class ApplicationContext(BaseContext): The command that this context belongs to. """ - command: Optional[ApplicationCommand] + command: ApplicationCommand | None def __init__( self, bot: Bot, interaction: Interaction, *, - command: Optional[ApplicationCommand] = None, - args: List[Any] = None, - kwargs: Dict[str, Any] = None, + command: ApplicationCommand | None = None, + args: list[Any] = None, + kwargs: dict[str, Any] = None, **kwargs2, ): super().__init__(bot=bot, command=command, args=args, kwargs=kwargs, **kwargs2) diff --git a/discord/commands/core.py b/discord/commands/core.py index c04980ba71..b40700c309 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -108,7 +108,7 @@ class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): def __init__(self, func: Callable, **kwargs) -> None: super().__init__(func, **kwargs) self.id: int | None = kwargs.get("id") - self.guild_ids: List[int] | None = kwargs.get("guild_ids", None) + self.guild_ids: list[int] | None = kwargs.get("guild_ids", None) # Permissions self.default_member_permissions: Permissions | None = getattr( @@ -278,7 +278,7 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: return final_options - def _match_option_param_names(self, params, options: List[Option]): + def _match_option_param_names(self, params, options: list[Option]): params = self._check_required_params(params) check_annotations: list[Callable[[Option, type], bool]] = [ diff --git a/discord/errors.py b/discord/errors.py index 9e7466f0ca..f60d2763f9 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -284,7 +284,7 @@ class CommandError(DiscordException): from :class:`.Bot`\, :func:`.on_command_error`. """ - def __init__(self, message: Optional[str] = None, *args: Any) -> None: + def __init__(self, message: str | None = None, *args: Any) -> None: if message is not None: # clean-up @everyone and @here mentions m = message.replace("@everyone", "@\u200beveryone").replace( From 3fd4d086309e84d914fc7da7198e55792c3b10ec Mon Sep 17 00:00:00 2001 From: BobDotCom <71356958+BobDotCom@users.noreply.github.com> Date: Mon, 17 Oct 2022 14:34:53 -0500 Subject: [PATCH 32/54] Fix more typehints --- discord/ext/commands/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 7fb908c12e..c5337deddb 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -83,7 +83,7 @@ class Context(BaseContext, Generic[BotT]): or invoked. """ - command: Optional[Command] + command: Command | None def __init__( self, From 88c90ca051b0853026a8049e27a4fb51372424f2 Mon Sep 17 00:00:00 2001 From: Middledot <78228142+Middledot@users.noreply.github.com> Date: Tue, 25 Oct 2022 00:01:29 -0500 Subject: [PATCH 33/54] fix(commands): apply suggestions Co-Authored-By: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Co-authored-by: BobDotCom <71356958+BobDotCom@users.noreply.github.com> --- discord/commands/core.py | 1 - discord/commands/mixins.py | 183 ++++++++++++++++++++++++------------- 2 files changed, 119 insertions(+), 65 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index b40700c309..868d091ea5 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -102,7 +102,6 @@ class ApplicationCommand(Invokable, _BaseCommand, Generic[CogT, P, T]): This is a subclass of :class:`.Invokable`. """ - __original_kwargs__: dict[str, Any] cog = None def __init__(self, func: Callable, **kwargs) -> None: diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index bca9db4665..09195e379f 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -1,3 +1,28 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + from __future__ import annotations import asyncio @@ -14,10 +39,10 @@ CommandOnCooldown, DisabledCommand, ) -from .cooldowns import BucketType, CooldownMapping, MaxConcurrency +from .cooldowns import BucketType, CooldownMapping, MaxConcurrency, Cooldown if TYPE_CHECKING: - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Concatenate from ..abc import MessageableChannel from ..bot import AutoShardedBot, Bot @@ -34,18 +59,18 @@ else: P = TypeVar("P") + BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") CogT = TypeVar("CogT", bound="Cog") -CallbackT = TypeVar("CallbackT") -ContextT = TypeVar("ContextT", bound="BaseContext") T = TypeVar("T") Coro = Coroutine[Any, Any, T] +Callback = Callable[Concatenate[CogT, "BaseContext", P], Coro[T]] | Callable[Concatenate["BaseContext", P], Coro[T]] MaybeCoro = Union[T, Coro[T]] Check = Union[ - Callable[[CogT, ContextT], MaybeCoro[bool]], - Callable[[ContextT], MaybeCoro[bool]], + Callable[[CogT, "BaseContext"], MaybeCoro[bool]], + Callable[["BaseContext"], MaybeCoro[bool]], ] Error = Union[ @@ -54,7 +79,7 @@ ] ErrorT = TypeVar("ErrorT", bound="Error") -Hook = Union[Callable[[CogT, ContextT], Coro[Any]], Callable[[ContextT], Coro[Any]]] +Hook = Union[Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]]] HookT = TypeVar("HookT", bound="Hook") @@ -65,18 +90,17 @@ ) -def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: - partial = functools.partial +def unwrap_function(function: functools.partial | Callable) -> Callback: while True: if hasattr(function, "__wrapped__"): - function = function.__wrapped__ - elif isinstance(function, partial): + function = function.__wrapped__ # type: ignore # function may or may not have attribute + elif isinstance(function, functools.partial): function = function.func else: return function -def wrap_callback(coro): +def wrap_callback(coro: Callback): @functools.wraps(coro) async def wrapped(*args, **kwargs): try: @@ -92,9 +116,9 @@ async def wrapped(*args, **kwargs): return wrapped -def hooked_wrapped_callback(command: Invokable, ctx: ContextT, coro: CallbackT): +def hook_wrapped_callback(command: Invokable, ctx: BaseContext, coro: Callback): @functools.wraps(coro) - async def wrapped(*args, **kwargs): + async def wrapper(*args, **kwargs): try: ret = await coro(*args, **kwargs) except (ApplicationCommandError, CommandError): @@ -113,7 +137,7 @@ async def wrapped(*args, **kwargs): return ret - return wrapped + return wrapper class _BaseCommand: @@ -121,7 +145,7 @@ class _BaseCommand: class BaseContext(abc.Messageable, Generic[BotT]): - r"""A baseclass to provide ***basic & common functionality*** between + r"""A base class to provide ***basic & common functionality*** between :class:`.ApplicationContext` and :class:`~ext.commands.Context`. This is a subclass of :class:`~abc.Messageable` and can be used to @@ -206,7 +230,7 @@ async def invoke( ) -> T: r"""|coro| - Calls a command with the arguments given. + Invokes a command with the arguments given. This is useful if you want to just call the callback that a :class:`.Invokable` holds internally. @@ -244,7 +268,7 @@ def source(self) -> Message | Interaction: """Union[:class:`.Message`, :class:`.Interaction`]: Property to return a message or interaction depending on the context. """ - raise NotImplementedError() + raise NotImplementedError @property def _state(self) -> ConnectionState: @@ -262,14 +286,12 @@ def cog(self) -> Cog | None: @utils.cached_property def guild(self) -> Guild | None: - """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. - None if not available. - """ + """Optional[:class:`.Guild`]: Returns the guild associated with this context's command.""" return self.source.guild @utils.cached_property def guild_id(self) -> int | None: - """:class:`int`: Returns the ID of the guild associated with this context's command.""" + """Optional[:class:`int`]: Returns the ID of the guild associated with this context's command.""" return getattr(self.source, "guild_id", self.guild.id if self.guild else None) @utils.cached_property @@ -279,7 +301,7 @@ def channel(self) -> MessageableChannel: @utils.cached_property def channel_id(self) -> int | None: - """:class:`int`: Returns the ID of the channel associated with this context's command.""" + """Optional[:class:`int`]: Returns the ID of the channel associated with this context's command.""" return getattr( self.source, "channel_id", self.channel.id if self.channel else None ) @@ -296,9 +318,8 @@ def user(self) -> User | Member: @utils.cached_property def me(self) -> Member | ClientUser: - """Union[:class:`.Member`, :class:`.ClientUser`]: - Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message - message contexts, or when :meth:`.Intents.guilds` is absent. + """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the + :class:`.ClientUser` in private message contexts, or when :meth:`.Intents.guilds` is absent. """ # bot.user will never be None at this point. return self.guild.me if self.guild and self.guild.me else self.bot.user # type: ignore @@ -310,7 +331,7 @@ def voice_client(self) -> VoiceProtocol | None: class Invokable(Generic[CogT, P, T]): - r"""A baseclass to provide ***basic & common functionality*** between + r"""A base class to provide ***basic & common functionality*** between :class:`.ApplicationCommand` and :class:`~ext.commands.Command`. .. versionadded:: 2.2 @@ -344,32 +365,48 @@ class Invokable(Generic[CogT, P, T]): cooldown: Optional[:class:`Cooldown`] The cooldown applied when the command is invoked. """ + __original_kwargs__: dict[str, Any] - def __init__(self, func: CallbackT, **kwargs): - self.callback: CallbackT = func + def __new__(cls, *args, **kwargs) -> Invokable: + self = super().__new__(cls) + + self.__original_kwargs__ = kwargs.copy() + return self + + def __init__( + self, + func: Callback, + name: str | None = None, + enabled: bool = False, + cooldown_after_parsing: bool = False, + parent: Invokable | None = None, + checks: list[Check] = [], + cooldown: CooldownMapping | None = None, + max_concurrency: MaxConcurrency | None = None, + ): + self.callback: Callback = func self.parent: Invokable | None = ( parent - if isinstance((parent := kwargs.get("parent")), _BaseCommand) + if isinstance(parent, _BaseCommand) else None ) self.cog: CogT | None = None self.module: Any = None - self.name: str = str(kwargs.get("name", func.__name__)) - self.enabled: bool = kwargs.get("enabled", True) - self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) + self.name: str = str(name or func.__name__) + self.enabled: bool = enabled + self.cooldown_after_parsing: bool = cooldown_after_parsing # checks - if checks := getattr(func, "__commands_checks__", []): - checks.reverse() + if _checks := getattr(func, "__commands_checks__", []): + # combine all that we find (kwargs or decorator) + _checks.reverse() + checks += _checks - checks += kwargs.get( - "checks", [] - ) # combine all the checks we find (kwargs or decorator) self.checks: list[Check] = checks # cooldowns - cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) + cooldown = getattr(func, "__commands_cooldown__", cooldown) if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) @@ -384,7 +421,7 @@ def __init__(self, func: CallbackT, **kwargs): # max concurrency self._max_concurrency: MaxConcurrency | None = getattr( - func, "__commands_max_concurrency__", kwargs.get("max_concurrency") + func, "__commands_max_concurrency__", max_concurrency ) # hooks @@ -399,11 +436,12 @@ def __init__(self, func: CallbackT, **kwargs): self.on_error: Error | None @property - def callback(self) -> CallbackT: + def callback(self) -> Callback: + """Returns the command's callback.""" return self._callback @callback.setter - def callback(self, func: CallbackT) -> None: + def callback(self, func: Callback) -> None: if not asyncio.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") @@ -412,15 +450,16 @@ def callback(self, func: CallbackT) -> None: self.module = unwrap.__module__ @property - def cooldown(self): + def cooldown(self) -> Cooldown | None: + """Returns the cooldown for the command.""" return self._buckets._cooldown @property def qualified_name(self) -> str: """:class:`str`: Retrieves the fully qualified command name. - This is the full parent name with the command name as well. - For example, in ``?one two three`` the qualified name would be + This is the full name of the parent command with the subcommand name as well. + For example, in ``?one two three``, the qualified name would be ``one two three``. """ if not self.parent: @@ -443,8 +482,8 @@ def parents(self) -> list[Invokable]: """ entries = [] command = self - while command.parent is not None: # type: ignore - command = command.parent # type: ignore + while command.parent is not None: + command = command.parent entries.append(command) return entries @@ -474,7 +513,7 @@ def full_parent_name(self) -> str | None: def __str__(self) -> str: return self.qualified_name - async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): + async def __call__(self, ctx: BaseContext, *args: P.args, **kwargs: P.kwargs): """|coro| Calls the internal callback that the command holds. @@ -485,12 +524,13 @@ async def __call__(self, ctx: ContextT, *args: P.args, **kwargs: P.kwargs): invoke hooks, cooldowns, etc. You must take care to pass the proper arguments and types to this function. """ + new_args = (ctx, *args) if self.cog is not None: - return await self.callback(self.cog, ctx, *args, **kwargs) - return await self.callback(ctx, *args, **kwargs) + new_args = (self.cog, *args) + return await self.callback(*new_args, **kwargs) def update(self, **kwargs: Any) -> None: - """Updates the :class:`Command` instance with updated attribute. + """Updates the :class:`Invokable` instance with updated attribute. Similar to creating a new instance except it updates the current. """ @@ -577,7 +617,7 @@ def after_invoke(self, coro: HookT) -> HookT: self._after_invoke = coro return coro - async def can_run(self, ctx: ContextT) -> bool: + async def can_run(self, ctx: BaseContext) -> bool: """|coro| Checks if the command can be executed by checking all the predicates @@ -701,7 +741,7 @@ def _update_copy(self, kwargs: dict[str, Any]): def _set_cog(self, cog: CogT): self.cog = cog - def is_on_cooldown(self, ctx: ContextT) -> bool: + def is_on_cooldown(self, ctx: BaseContext) -> bool: """Checks whether the command is currently on cooldown. .. note:: @@ -762,7 +802,7 @@ def get_cooldown_retry_after(self, ctx) -> float: return 0.0 - def _prepare_cooldowns(self, ctx: ContextT): + def _prepare_cooldowns(self, ctx: BaseContext): if not self._buckets.valid: return @@ -775,7 +815,7 @@ def _prepare_cooldowns(self, ctx: ContextT): if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - async def call_before_hooks(self, ctx: ContextT) -> None: + async def call_before_hooks(self, ctx: BaseContext) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog @@ -800,7 +840,7 @@ async def call_before_hooks(self, ctx: ContextT) -> None: if hook is not None: await hook(ctx) - async def call_after_hooks(self, ctx: ContextT) -> None: + async def call_after_hooks(self, ctx: BaseContext) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, "__self__", cog) @@ -819,11 +859,11 @@ async def call_after_hooks(self, ctx: ContextT) -> None: if hook is not None: await hook(ctx) - async def _parse_arguments(self, ctx: ContextT) -> None: + async def _parse_arguments(self, ctx: BaseContext) -> None: """Parses arguments and attaches them to the context class (Union[:class:`~ext.commands.Context`, :class:`.ApplicationContext`])""" - raise NotImplementedError() + raise NotImplementedError - async def prepare(self, ctx: ContextT) -> None: + async def prepare(self, ctx: BaseContext) -> None: ctx.command = self if not await self.can_run(ctx): @@ -849,7 +889,7 @@ async def prepare(self, ctx: ContextT) -> None: await self._max_concurrency.release(ctx) # type: ignore raise - async def invoke(self, ctx: ContextT) -> None: + async def invoke(self, ctx: BaseContext) -> None: """Runs the command with checks. Parameters @@ -867,7 +907,22 @@ async def invoke(self, ctx: ContextT) -> None: injected = hooked_wrapped_callback(self, ctx, self.callback) await injected(*ctx.args, **ctx.kwargs) - async def reinvoke(self, ctx: ContextT, *, call_hooks: bool = False) -> None: + async def reinvoke(self, ctx: BaseContext, *, call_hooks: bool = False) -> None: + """|coro| + + Calls the command again. + + This is similar to :meth:`Invokable.invoke` except that it bypasses + checks, cooldowns, and error handlers. + + Parameters + ---------- + ctx: BaseContext + The context to invoke with. + call_hooks: :class:`bool` + Whether to call the before and after invoke hooks. + """ + ctx.command = self await self._parse_arguments(ctx) @@ -884,12 +939,12 @@ async def reinvoke(self, ctx: ContextT, *, call_hooks: bool = False) -> None: if call_hooks: await self.call_after_hooks(ctx) - async def _dispatch_error(self, ctx: ContextT, error: Exception) -> None: + async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: # since I don't want to copy paste code, subclassed Contexts # dispatch it to their corresponding events - raise NotImplementedError() + raise NotImplementedError - async def dispatch_error(self, ctx: ContextT, error: Exception) -> None: + async def dispatch_error(self, ctx: BaseContext, error: Exception) -> None: ctx.command_failed = True cog = self.cog From 656cf3728162c3f01ab227a773e953a52bda867b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 12 Nov 2022 17:35:16 +0000 Subject: [PATCH 34/54] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- discord/commands/context.py | 1 + discord/commands/core.py | 1 - discord/commands/mixins.py | 19 +++++++++++-------- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/discord/commands/context.py b/discord/commands/context.py index 2269bdf787..18e2ac231b 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -303,6 +303,7 @@ async def delete(self, *, delay: float | None = None) -> None: def edit(self) -> Callable[..., Awaitable[InteractionMessage]]: return self.interaction.edit_original_response + class AutocompleteContext: """Represents context for a slash command's option autocomplete. This ***does not*** inherent from :class:`.BaseContext`. diff --git a/discord/commands/core.py b/discord/commands/core.py index d1ce89b9f7..028fe97ff3 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -157,7 +157,6 @@ def to_dict(self) -> dict[str, Any]: raise NotImplementedError - class SlashCommand(ApplicationCommand): """A class that implements the protocol for a slash command. diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index d62bbb410d..823b79f068 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -39,10 +39,10 @@ CommandOnCooldown, DisabledCommand, ) -from .cooldowns import BucketType, CooldownMapping, MaxConcurrency, Cooldown +from .cooldowns import BucketType, Cooldown, CooldownMapping, MaxConcurrency if TYPE_CHECKING: - from typing_extensions import ParamSpec, Concatenate + from typing_extensions import Concatenate, ParamSpec from ..abc import MessageableChannel from ..bot import AutoShardedBot, Bot @@ -65,7 +65,10 @@ T = TypeVar("T") Coro = Coroutine[Any, Any, T] -Callback = Callable[Concatenate[CogT, "BaseContext", P], Coro[T]] | Callable[Concatenate["BaseContext", P], Coro[T]] +Callback = ( + Callable[Concatenate[CogT, "BaseContext", P], Coro[T]] + | Callable[Concatenate["BaseContext", P], Coro[T]] +) MaybeCoro = Union[T, Coro[T]] Check = Union[ @@ -79,7 +82,9 @@ ] ErrorT = TypeVar("ErrorT", bound="Error") -Hook = Union[Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]]] +Hook = Union[ + Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]] +] HookT = TypeVar("HookT", bound="Hook") @@ -393,9 +398,7 @@ def __init__( ): self.callback: Callback = func self.parent: Invokable | None = ( - parent - if isinstance(parent, _BaseCommand) - else None + parent if isinstance(parent, _BaseCommand) else None ) self.cog: CogT | None = None self.module: Any = None @@ -534,7 +537,7 @@ async def __call__(self, ctx: BaseContext, *args: P.args, **kwargs: P.kwargs): new_args = (ctx, *args) if self.cog is not None: new_args = (self.cog, *args) - return await self.callback(*new_args, **kwargs) + return await self.callback(*new_args, **kwargs) def update(self, **kwargs: Any) -> None: """Updates the :class:`Invokable` instance with updated attribute. From c495966bf4045bc12eb73707ff91d0983078c1cf Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 12 Nov 2022 12:47:00 -0500 Subject: [PATCH 35/54] refactor(commands): the other function rename --- discord/commands/mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index d62bbb410d..a1bcde9d68 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -102,7 +102,7 @@ def unwrap_function(function: functools.partial | Callable) -> Callback: def wrap_callback(coro: Callback): @functools.wraps(coro) - async def wrapped(*args, **kwargs): + async def wrapper(*args, **kwargs): try: ret = await coro(*args, **kwargs) except CommandError: @@ -113,7 +113,7 @@ async def wrapped(*args, **kwargs): raise CommandInvokeError(exc) from exc return ret - return wrapped + return wrapper def hook_wrapped_callback(command: Invokable, ctx: BaseContext, coro: Callback): From ccc61b7b55f062b3af096e9909621b55ac39010c Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 14 Nov 2022 21:08:13 -0500 Subject: [PATCH 36/54] refactor(commands): avoid 'type: ignore' from suggestions Co-Authored-By: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> --- discord/commands/mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index f29bda75f0..389ddbc009 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -98,7 +98,7 @@ def unwrap_function(function: functools.partial | Callable) -> Callback: while True: if hasattr(function, "__wrapped__"): - function = function.__wrapped__ # type: ignore # function may or may not have attribute + function = getattr(function, "__wrapped__") elif isinstance(function, functools.partial): function = function.func else: From 36e9a9a389556bc759d30966f076c80627013836 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 14 Nov 2022 21:19:36 -0500 Subject: [PATCH 37/54] fix(ext.commands): add missing import --- discord/ext/commands/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index d9b9a85ddf..a9705467ee 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -196,6 +196,7 @@ async def send_help(self, *args: Any) -> Any: The result of the help command, if any. """ from ...commands.mixins import wrap_callback + from .core import Group from .errors import CommandError bot = self.bot From 2110284dbfebb0d1a074e495887e7837e8c94d5a Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 14 Nov 2022 21:23:31 -0500 Subject: [PATCH 38/54] fix(commands): move typehints because "Concatenate is not defined" error --- discord/commands/mixins.py | 54 +++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 389ddbc009..07e8bac3fd 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -56,37 +56,37 @@ from ..voice_client import VoiceProtocol P = ParamSpec("P") + + BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") + CogT = TypeVar("CogT", bound="Cog") + + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + Callback = ( + Callable[Concatenate[CogT, "BaseContext", P], Coro[T]] + | Callable[Concatenate["BaseContext", P], Coro[T]] + ) + MaybeCoro = Union[T, Coro[T]] + + Check = Union[ + Callable[[CogT, "BaseContext"], MaybeCoro[bool]], + Callable[["BaseContext"], MaybeCoro[bool]], + ] + + Error = Union[ + Callable[[CogT, "BaseContext[Any]", CommandError], Coro[Any]], + Callable[["BaseContext[Any]", CommandError], Coro[Any]], + ] + ErrorT = TypeVar("ErrorT", bound="Error") + + Hook = Union[ + Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]] + ] + HookT = TypeVar("HookT", bound="Hook") else: P = TypeVar("P") -BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") -CogT = TypeVar("CogT", bound="Cog") - -T = TypeVar("T") -Coro = Coroutine[Any, Any, T] -Callback = ( - Callable[Concatenate[CogT, "BaseContext", P], Coro[T]] - | Callable[Concatenate["BaseContext", P], Coro[T]] -) -MaybeCoro = Union[T, Coro[T]] - -Check = Union[ - Callable[[CogT, "BaseContext"], MaybeCoro[bool]], - Callable[["BaseContext"], MaybeCoro[bool]], -] - -Error = Union[ - Callable[[CogT, "BaseContext[Any]", CommandError], Coro[Any]], - Callable[["BaseContext[Any]", CommandError], Coro[Any]], -] -ErrorT = TypeVar("ErrorT", bound="Error") - -Hook = Union[ - Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]] -] -HookT = TypeVar("HookT", bound="Hook") - __all__ = ( "Invokable", From 3f0c9061fa2ca99dad1d689cb72da43382220d0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Nov 2022 02:23:57 +0000 Subject: [PATCH 39/54] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- discord/commands/mixins.py | 1 - 1 file changed, 1 deletion(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 07e8bac3fd..72e2381c60 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -87,7 +87,6 @@ P = TypeVar("P") - __all__ = ( "Invokable", "_BaseCommand", From 20d255cb9d4cc303322fe202186ffd92ddfde578 Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 14 Nov 2022 21:25:15 -0500 Subject: [PATCH 40/54] chore(workflows): update list of allowed words --- .github/workflows/check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d64c44f2a5..1e0818ad6f 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -20,7 +20,7 @@ jobs: python -m pip install --upgrade pip pip install -r requirements/dev.txt - run: - codespell --ignore-words-list="groupt,nd,ot,ro,falsy,BU" \ + codespell --ignore-words-list="groupt,nd,ot,ro,falsy,BU,invokable" \ --exclude-file=".github/workflows/codespell.yml" bandit: runs-on: ubuntu-latest From b51552af54d9502864dad1a585dbf4cbf04295aa Mon Sep 17 00:00:00 2001 From: Middledot Date: Mon, 14 Nov 2022 21:29:46 -0500 Subject: [PATCH 41/54] fix(commands): move typehints again --- discord/commands/mixins.py | 46 ++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 72e2381c60..38a107f57e 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -41,6 +41,29 @@ ) from .cooldowns import BucketType, Cooldown, CooldownMapping, MaxConcurrency +T = TypeVar("T") +BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") +CogT = TypeVar("CogT", bound="Cog") + +Coro = Coroutine[Any, Any, T] +MaybeCoro = Union[T, Coro[T]] + +Check = Union[ + Callable[[CogT, "BaseContext"], MaybeCoro[bool]], + Callable[["BaseContext"], MaybeCoro[bool]], +] + +Error = Union[ + Callable[[CogT, "BaseContext[Any]", CommandError], Coro[Any]], + Callable[["BaseContext[Any]", CommandError], Coro[Any]], +] +ErrorT = TypeVar("ErrorT", bound="Error") + +Hook = Union[ + Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]] +] +HookT = TypeVar("HookT", bound="Hook") + if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec @@ -57,34 +80,13 @@ P = ParamSpec("P") - BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") - CogT = TypeVar("CogT", bound="Cog") - - T = TypeVar("T") - Coro = Coroutine[Any, Any, T] Callback = ( Callable[Concatenate[CogT, "BaseContext", P], Coro[T]] | Callable[Concatenate["BaseContext", P], Coro[T]] ) - MaybeCoro = Union[T, Coro[T]] - - Check = Union[ - Callable[[CogT, "BaseContext"], MaybeCoro[bool]], - Callable[["BaseContext"], MaybeCoro[bool]], - ] - - Error = Union[ - Callable[[CogT, "BaseContext[Any]", CommandError], Coro[Any]], - Callable[["BaseContext[Any]", CommandError], Coro[Any]], - ] - ErrorT = TypeVar("ErrorT", bound="Error") - - Hook = Union[ - Callable[[CogT, "BaseContext"], Coro[Any]], Callable[["BaseContext"], Coro[Any]] - ] - HookT = TypeVar("HookT", bound="Hook") else: P = TypeVar("P") + Callback = TypeVar("Callback") __all__ = ( From 4a8e38e3c7561fff7243a7ab8377bd4b3b28bf39 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sun, 18 Dec 2022 12:46:43 -0500 Subject: [PATCH 42/54] fix(commands, ext.commands): apply suggestions --- discord/bot.py | 2 +- discord/commands/core.py | 2 +- discord/commands/mixins.py | 47 +++++++++++++++++++++++---------- discord/ext/commands/bot.py | 2 +- discord/ext/commands/cog.py | 4 +-- discord/ext/commands/context.py | 2 +- discord/ext/commands/core.py | 8 +++--- discord/ext/commands/help.py | 4 +-- 8 files changed, 45 insertions(+), 26 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index ddfcd1c99b..f2905ed0e8 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -1116,7 +1116,7 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None: else: raise CheckFailure("The global check once functions failed.") except DiscordException as exc: - await ctx.command.dispatch_error(ctx, exc) + await ctx.command._dispatch_error(ctx, exc) else: self._bot.dispatch("application_command_completion", ctx) diff --git a/discord/commands/core.py b/discord/commands/core.py index 028fe97ff3..350d404544 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -139,7 +139,7 @@ def __eq__(self, other) -> bool: def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) - async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + async def __dispatch_error(self, ctx: BaseContext, error: Exception) -> None: ctx.bot.dispatch("application_command_error", ctx, error) @property diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 38a107f57e..597adc3ac8 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -28,7 +28,7 @@ import asyncio import datetime import functools -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar, Union, overload from .. import abc, utils from ..errors import ( @@ -139,7 +139,7 @@ async def wrapper(*args, **kwargs): finally: if command._max_concurrency is not None: await command._max_concurrency.release(ctx) - await command.call_after_hooks(ctx) + await command._call_after_hooks(ctx) return ret @@ -537,10 +537,29 @@ async def __call__(self, ctx: BaseContext, *args: P.args, **kwargs: P.kwargs): """ new_args = (ctx, *args) if self.cog is not None: - new_args = (self.cog, *args) + new_args = (self.cog, *new_args) return await self.callback(*new_args, **kwargs) - def update(self, **kwargs: Any) -> None: + @overload + def update( + self, + *, + func: Callback | None = ..., + name: str | None = ..., + enabled: bool = False, + cooldown_after_parsing: bool = ..., + parent: Invokable | None = ..., + checks: list[Check] = ..., + cooldown: CooldownMapping | None = ..., + max_concurrency: MaxConcurrency | None = ..., + ) -> None: + ... + + @overload + def update(self) -> None: + ... + + def update(self, **kwargs) -> None: """Updates the :class:`Invokable` instance with updated attribute. Similar to creating a new instance except it updates the current. @@ -802,7 +821,7 @@ def _prepare_cooldowns(self, ctx: BaseContext): if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - async def call_before_hooks(self, ctx: BaseContext) -> None: + async def _call_before_hooks(self, ctx: BaseContext) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog @@ -827,7 +846,7 @@ async def call_before_hooks(self, ctx: BaseContext) -> None: if hook is not None: await hook(ctx) - async def call_after_hooks(self, ctx: BaseContext) -> None: + async def _call_after_hooks(self, ctx: BaseContext) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, "__self__", cog) @@ -850,7 +869,7 @@ async def _parse_arguments(self, ctx: BaseContext) -> None: """Parses arguments and attaches them to the context class (Union[:class:`~ext.commands.Context`, :class:`.ApplicationContext`])""" raise NotImplementedError - async def prepare(self, ctx: BaseContext) -> None: + async def _prepare(self, ctx: BaseContext) -> None: ctx.command = self if not await self.can_run(ctx): @@ -870,7 +889,7 @@ async def prepare(self, ctx: BaseContext) -> None: self._prepare_cooldowns(ctx) await self._parse_arguments(ctx) - await self.call_before_hooks(ctx) + await self._call_before_hooks(ctx) except: if self._max_concurrency is not None: await self._max_concurrency.release(ctx) # type: ignore @@ -884,7 +903,7 @@ async def invoke(self, ctx: BaseContext) -> None: ctx: :class:`.BaseContext` The context to pass into the command. """ - await self.prepare(ctx) + await self._prepare(ctx) # terminate the invoked_subcommand chain. # since we're in a regular command (and not a group) then @@ -914,7 +933,7 @@ async def reinvoke(self, ctx: BaseContext, *, call_hooks: bool = False) -> None: await self._parse_arguments(ctx) if call_hooks: - await self.call_before_hooks(ctx) + await self._call_before_hooks(ctx) ctx.invoked_subcommand = None try: @@ -924,14 +943,14 @@ async def reinvoke(self, ctx: BaseContext, *, call_hooks: bool = False) -> None: raise finally: if call_hooks: - await self.call_after_hooks(ctx) + await self._call_after_hooks(ctx) - async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + async def __dispatch_error(self, ctx: BaseContext, error: Exception) -> None: # since I don't want to copy paste code, subclassed Contexts # dispatch it to their corresponding events raise NotImplementedError - async def dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: ctx.command_failed = True cog = self.cog @@ -949,4 +968,4 @@ async def dispatch_error(self, ctx: BaseContext, error: Exception) -> None: wrapped = wrap_callback(local) await wrapped(ctx, error) finally: - await self._dispatch_error(ctx, error) + await self.__dispatch_error(ctx, error) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index ae8a48c283..c1637ab3fd 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -348,7 +348,7 @@ async def invoke(self, ctx: Context) -> None: else: raise errors.CheckFailure("The global check once functions failed.") except errors.CommandError as exc: - await ctx.command.dispatch_error(ctx, exc) + await ctx.command._dispatch_error(ctx, exc) else: self.dispatch("command_completion", ctx) elif ctx.invoked_with: diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 5f9eacb72a..b0bba0b8c8 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -54,7 +54,7 @@ def walk_commands(self) -> Generator[Invokable, None, None]: Yields ------ - Union[:class:`.Invokable`] + :class:`.Invokable` A command or group from the cog. """ from .core import GroupMixin @@ -74,7 +74,7 @@ def get_commands(self) -> list[ApplicationCommand | Command]: r""" Returns -------- - List[:class:`~discord.Invokable`]] + List[:class:`.Invokable`] A :class:`list` of commands that are defined inside this cog. .. note:: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index a9705467ee..6580b965e0 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -96,7 +96,7 @@ def __init__( prefix: str | None = None, command: Command | None = None, current_parameter: inspect.Parameter | None = None, - **kwargs2, + **kwargs2: dict[str, Any], ): super().__init__(bot=bot, command=command, args=args, kwargs=kwargs, **kwargs2) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index e2de60475b..2c2752cfe3 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -289,7 +289,7 @@ def callback(self, func: CallbackT) -> None: self.params = get_signature_parameters(func, globalns) - async def _dispatch_error(self, ctx: Context, error: Exception) -> None: + async def __dispatch_error(self, ctx: Context, error: Exception) -> None: ctx.bot.dispatch("command_error", ctx, error) async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: @@ -871,7 +871,7 @@ async def invoke(self, ctx: Context) -> None: ctx.subcommand_passed = None early_invoke = not self.invoke_without_command if early_invoke: - await self.prepare(ctx) + await self._prepare(ctx) view = ctx.view previous = view.index @@ -905,7 +905,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: await self._parse_arguments(ctx) if call_hooks: - await self.call_before_hooks(ctx) + await self._call_before_hooks(ctx) view = ctx.view previous = view.index @@ -924,7 +924,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: raise finally: if call_hooks: - await self.call_after_hooks(ctx) + await self._call_after_hooks(ctx) ctx.invoked_parents.append(ctx.invoked_with) # type: ignore diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 90899ef6a8..7edb8a5a00 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -199,7 +199,7 @@ def __init__(self, inject, *args, **kwargs): self._original = inject self._injected = inject - async def prepare(self, ctx): + async def _prepare(self, ctx): self._injected = injected = self._original.copy() injected.context = ctx self.callback = injected.command_callback @@ -211,7 +211,7 @@ async def prepare(self, ctx): else: self.on_error = on_error - await super().prepare(ctx) + await super()._prepare(ctx) async def _parse_arguments(self, ctx): # Make the parser think we don't have a cog so it doesn't From 00acc4eb4cff712a271d0beb91c45dec791dc133 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Dec 2022 17:51:52 +0000 Subject: [PATCH 43/54] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- discord/commands/mixins.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 597adc3ac8..12b073ad6b 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -28,7 +28,16 @@ import asyncio import datetime import functools -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Generic, + TypeVar, + Union, + overload, +) from .. import abc, utils from ..errors import ( From 791aef43636a1e759e50286f36b1a3da586ca256 Mon Sep 17 00:00:00 2001 From: Middledot Date: Fri, 28 Apr 2023 23:01:00 -0400 Subject: [PATCH 44/54] fix(commands): silly error where enabled is False --- discord/commands/mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 12b073ad6b..6c24b26e5d 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -399,7 +399,7 @@ def __init__( self, func: Callback, name: str | None = None, - enabled: bool = False, + enabled: bool = True, cooldown_after_parsing: bool = False, parent: Invokable | None = None, checks: list[Check] = [], From c0ea4e1f355a5fa3e2d79bac7161016fe4a9a1e3 Mon Sep 17 00:00:00 2001 From: Middledot Date: Fri, 28 Apr 2023 23:01:44 -0400 Subject: [PATCH 45/54] fix(commands): `.help` is text exclusive --- discord/ext/commands/core.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 44c2bdcf0e..c05253518b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -241,9 +241,7 @@ def __init__( func: CallbackT, **kwargs: Any, ): - super().__init__(func, **kwargs) - - help_doc = kwargs.get("help") + help_doc = kwargs.pop("help", None) if help_doc is not None: help_doc = inspect.cleandoc(help_doc) else: @@ -252,23 +250,24 @@ def __init__( help_doc = help_doc.decode("utf-8") self.help: str | None = help_doc - - self.brief: str | None = kwargs.get("brief") - self.usage: str | None = kwargs.get("usage") - self.rest_is_raw: bool = kwargs.get("rest_is_raw", False) - self.aliases: list[str] | tuple[str] = kwargs.get("aliases", []) - self.extras: dict[str, Any] = kwargs.get("extras", {}) + self.brief: str | None = kwargs.pop("brief", None) + self.usage: str | None = kwargs.pop("usage", None) + self.rest_is_raw: bool = kwargs.pop("rest_is_raw", False) + self.aliases: list[str] | tuple[str] = kwargs.pop("aliases", []) + self.extras: dict[str, Any] = kwargs.pop("extras", {}) if not isinstance(self.aliases, (list, tuple)): raise TypeError( "Aliases of a command must be a list or a tuple of strings." ) - self.description: str = inspect.cleandoc(kwargs.get("description", "")) - self.hidden: bool = kwargs.get("hidden", False) + self.description: str = inspect.cleandoc(kwargs.pop("description", "")) + self.hidden: bool = kwargs.pop("hidden", False) - self.require_var_positional: bool = kwargs.get("require_var_positional", False) - self.ignore_extra: bool = kwargs.get("ignore_extra", True) + self.require_var_positional: bool = kwargs.pop("require_var_positional", False) + self.ignore_extra: bool = kwargs.pop("ignore_extra", True) + + super().__init__(func, **kwargs) @property def callback( @@ -289,7 +288,7 @@ def callback(self, func: CallbackT) -> None: self.params = get_signature_parameters(func, globalns) - async def __dispatch_error(self, ctx: Context, error: Exception) -> None: + async def _dispatch_error(self, ctx: Context, error: Exception) -> None: ctx.bot.dispatch("command_error", ctx, error) async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: From 9141c45b93c50aa8c29c514fcff00a8975c22ba8 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 29 Apr 2023 00:41:04 -0400 Subject: [PATCH 46/54] chore(commands): efficientify the copy func --- discord/cog.py | 2 +- discord/commands/mixins.py | 40 ++++++++++--------------------------- discord/ext/commands/bot.py | 2 +- 3 files changed, 12 insertions(+), 32 deletions(-) diff --git a/discord/cog.py b/discord/cog.py index 811ccfe1bd..25550a461a 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -231,7 +231,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: # Either update the command with the cog provided defaults or copy it. # r.e type ignore, type-checker complains about overriding a ClassVar - new_cls.__cog_commands__ = tuple(c._update_copy(cmd_attrs) if not hasattr(c, "add_to") else c for c in new_cls.__cog_commands__) # type: ignore + new_cls.__cog_commands__ = tuple(c.copy(cmd_attrs) if not hasattr(c, "add_to") else c for c in new_cls.__cog_commands__) # type: ignore name_filter = lambda c: ( "app" diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 6c24b26e5d..469917dcf2 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -28,6 +28,7 @@ import asyncio import datetime import functools +from copy import copy from typing import ( TYPE_CHECKING, Any, @@ -716,7 +717,7 @@ def remove_check(self, func: Check) -> None: except ValueError: pass - def copy(self): + def copy(self, kwargs: dict[str, Any] | None = None): """Creates a copy of this command. Returns @@ -724,34 +725,13 @@ def copy(self): :class:`Invokable` A new instance of this command. """ - ret = self.__class__(self.callback, **self.__original_kwargs__) - return self._ensure_assignment_on_copy(ret) - - def _ensure_assignment_on_copy(self, other: Invokable): - other._before_invoke = self._before_invoke - other._after_invoke = self._after_invoke - if self.checks != other.checks: - other.checks = self.checks.copy() - if self._buckets.valid and not other._buckets.valid: - other._buckets = self._buckets.copy() - if self._max_concurrency != other._max_concurrency: - # _max_concurrency won't be None at this point - other._max_concurrency = self._max_concurrency.copy() # type: ignore - try: - other.on_error = self.on_error - except AttributeError: - pass - return other - - def _update_copy(self, kwargs: dict[str, Any]): + thecopy = copy(self) if kwargs: - kw = kwargs.copy() - kw.update(self.__original_kwargs__) - copy = self.__class__(self.callback, **kw) - return self._ensure_assignment_on_copy(copy) - else: - return self.copy() + for attr, val in kwargs.items(): + setattr(thecopy, attr, val) + + return thecopy def _set_cog(self, cog: CogT): self.cog = cog @@ -954,12 +934,12 @@ async def reinvoke(self, ctx: BaseContext, *, call_hooks: bool = False) -> None: if call_hooks: await self._call_after_hooks(ctx) - async def __dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: # since I don't want to copy paste code, subclassed Contexts # dispatch it to their corresponding events raise NotImplementedError - async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + async def dispatch_error(self, ctx: BaseContext, error: Exception) -> None: ctx.command_failed = True cog = self.cog @@ -977,4 +957,4 @@ async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: wrapped = wrap_callback(local) await wrapped(ctx, error) finally: - await self.__dispatch_error(ctx, error) + await self._dispatch_error(ctx, error) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index f69807d72c..4f1753e49a 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -348,7 +348,7 @@ async def invoke(self, ctx: Context) -> None: else: raise errors.CheckFailure("The global check once functions failed.") except errors.CommandError as exc: - await ctx.command._dispatch_error(ctx, exc) + await ctx.command.dispatch_error(ctx, exc) else: self.dispatch("command_completion", ctx) elif ctx.invoked_with: From 85abb41e10eaefc016b5e38c91018c9a182f6578 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 29 Apr 2023 00:48:40 -0400 Subject: [PATCH 47/54] chore(errors): deprecate ACInvokeError and ACError --- discord/errors.py | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/discord/errors.py b/discord/errors.py index 487b8c69d0..f3efc25672 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -296,17 +296,6 @@ def __init__(self, message: str | None = None, *args: Any) -> None: super().__init__(*args) -class ApplicationCommandError(CommandError): - r"""The base exception type for all application command related errors. - - This inherits from :exc:`DiscordException`. - - This exception and exceptions inherited from it are handled - in a special way as they are caught and passed into a special event - from :class:`.Bot`\, :func:`.on_command_error`. - """ - - class CommandInvokeError(CommandError): """Exception raised when the command being invoked raised an exception. @@ -324,23 +313,8 @@ def __init__(self, e: Exception) -> None: super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") -class ApplicationCommandInvokeError(ApplicationCommandError): - """Exception raised when the command being invoked raised an exception. - - This inherits from :exc:`ApplicationCommandError` - - Attributes - ---------- - original: :exc:`Exception` - The original exception that was raised. You can also get this via - the ``__cause__`` attribute. - """ - - def __init__(self, e: Exception) -> None: - self.original: Exception = e - super().__init__( - f"Application Command raised an exception: {e.__class__.__name__}: {e}" - ) +ApplicationCommandError = CommandError +ApplicationCommandInvokeError = CommandInvokeError class CheckFailure(CommandError): From d8bf4bc2d1793a5b021476dc946423aeefb182e1 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 29 Apr 2023 00:58:55 -0400 Subject: [PATCH 48/54] chore(commands): add last dispatch_error changes --- discord/bot.py | 2 +- discord/commands/core.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index 7bbf9ec4a0..0e8cad3036 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -1116,7 +1116,7 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None: else: raise CheckFailure("The global check once functions failed.") except DiscordException as exc: - await ctx.command._dispatch_error(ctx, exc) + await ctx.command.dispatch_error(ctx, exc) else: self._bot.dispatch("application_command_completion", ctx) diff --git a/discord/commands/core.py b/discord/commands/core.py index 0856f9998b..c4c7d0a201 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -139,7 +139,7 @@ def __eq__(self, other) -> bool: def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) - async def __dispatch_error(self, ctx: BaseContext, error: Exception) -> None: + async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: ctx.bot.dispatch("application_command_error", ctx, error) @property From ce44ccbe2523ec41ff822ffa80828c6e2dcc0df6 Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 10 Jun 2023 21:01:19 -0400 Subject: [PATCH 49/54] fix(cmds): fix 3 bugs found while testing 1. Parameters that can't be used by slash commands (like aliases) go to Invokable and fail - Added **kwargs as catchall 2. Checks pile up on the same object (the default) causing all commands to have the same checks (quite silly) - Just make the default None 3. Groups have duplicates cuz of copying - Don't duplicate them --- discord/commands/mixins.py | 12 ++++++++---- discord/ext/commands/core.py | 5 +++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 469917dcf2..c897e82a3b 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -403,9 +403,10 @@ def __init__( enabled: bool = True, cooldown_after_parsing: bool = False, parent: Invokable | None = None, - checks: list[Check] = [], + checks: list[Check] | None = None, cooldown: CooldownMapping | None = None, max_concurrency: MaxConcurrency | None = None, + **kwargs ): self.callback: Callback = func self.parent: Invokable | None = ( @@ -418,6 +419,9 @@ def __init__( self.enabled: bool = enabled self.cooldown_after_parsing: bool = cooldown_after_parsing + if not checks: + checks = [] + # checks if _checks := getattr(func, "__commands_checks__", []): # combine all that we find (kwargs or decorator) @@ -717,7 +721,7 @@ def remove_check(self, func: Check) -> None: except ValueError: pass - def copy(self, kwargs: dict[str, Any] | None = None): + def copy(self, overrides: dict[str, Any] | None = None): """Creates a copy of this command. Returns @@ -727,8 +731,8 @@ def copy(self, kwargs: dict[str, Any] | None = None): """ thecopy = copy(self) - if kwargs: - for attr, val in kwargs.items(): + if overrides: + for attr, val in overrides.items(): setattr(thecopy, attr, val) return thecopy diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index c05253518b..f764135531 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -852,7 +852,7 @@ def __init__(self, *args: Any, **attrs: Any) -> None: self.invoke_without_command: bool = attrs.pop("invoke_without_command", False) super().__init__(*args, **attrs) - def copy(self: GroupT) -> GroupT: + def copy(self: GroupT, overrides: dict[str, Any] | None = None) -> GroupT: """Creates a copy of this :class:`Group`. Returns @@ -860,7 +860,8 @@ def copy(self: GroupT) -> GroupT: :class:`Group` A new instance of this group. """ - ret = super().copy() + ret = super().copy(overrides) + ret.recursively_remove_all_commands() for cmd in self.commands: ret.add_command(cmd.copy()) return ret # type: ignore From 45c2e96dee91fe48ce20e68756ad883739ee5793 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 11 Jun 2023 01:02:01 +0000 Subject: [PATCH 50/54] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/commands/mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index c897e82a3b..d13d52b8a1 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -406,7 +406,7 @@ def __init__( checks: list[Check] | None = None, cooldown: CooldownMapping | None = None, max_concurrency: MaxConcurrency | None = None, - **kwargs + **kwargs, ): self.callback: Callback = func self.parent: Invokable | None = ( From d9d536c583796f9a6f0c4e84d8a4f9f1e68307bc Mon Sep 17 00:00:00 2001 From: Middledot Date: Tue, 27 Jun 2023 18:28:42 -0400 Subject: [PATCH 51/54] chore(*): cleanup and changelog - Made a changelog - Fixed some docs and typing of functions - Revert renaming prepare and call_*_hooks - Revert "slash cmds can't run their parents checks" --- CHANGELOG.md | 25 +++++++++++ discord/bot.py | 17 ++++---- discord/commands/core.py | 64 +++++++++++++++++++++++++++- discord/commands/mixins.py | 70 ++++--------------------------- discord/errors.py | 1 - discord/ext/commands/cooldowns.py | 2 + discord/ext/commands/core.py | 64 ++++++++++++++++++++++++++-- discord/ext/commands/errors.py | 2 +- discord/ext/commands/help.py | 4 +- 9 files changed, 170 insertions(+), 79 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c94cf00022..94e5920469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,24 @@ These changes are available on the `master` branch, but have not yet been releas ([#2086](https://github.com/Pycord-Development/pycord/pull/2086)) - Added new embedded activities, Gartic Phone and Jamspace. ([#2102](https://github.com/Pycord-Development/pycord/pull/2102)) +- Added `_parse_arguments` function to slash commands to parse arguments instead of in + `prepare`. ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Added the `parents`, `root_parent`, and `cog_name` attribute to slash commands. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Added the `error` decorator to slash commands. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Added the `cooldown_after_parsing` parameter & attribute to slash commands. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Added the `reinvoke` function to `ApplicationContext`. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Added the `invoked_with`, `invoked_parents`, `invoked_subcommand`, `subcommand_passed`, + and `command_failed` parameters & attributes to `ApplicationContext`. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Added the `source` attribute to `ext.commands.Context` & `ApplicationContext` for a + common way to either retrieve the message or the interaction that triggered the + command. ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Documented `ContextMenuCommand`. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) ### Changed @@ -84,6 +102,13 @@ These changes are available on the `master` branch, but have not yet been releas ([#2087](https://github.com/Pycord-Development/pycord/pull/2087)) - Typehinted `command_prefix` and `help_command` arguments properly. ([#2099](https://github.com/Pycord-Development/pycord/pull/2099)) +- Renamed `_invoke` in application commands to `invoke`. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Deprecated `ext/commands/cooldowns` in favour of `commands/cooldowns`. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) +- Deprecated `ApplicationCommandError` & `ApplicationCommandInvokeError` in favour of + `CommandError` & `CommandInvokeError` respectively. + ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) ### Removed diff --git a/discord/bot.py b/discord/bot.py index e0a61074b9..1669574886 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -34,11 +34,12 @@ import sys import traceback from abc import ABC, abstractmethod -from typing import Any, Callable, Coroutine, Generator, Literal, Mapping, TypeVar +from typing import Any, Callable, Coroutine, Generator, Literal, Mapping, TypeVar, Type from .client import Client from .cog import CogMixin from .commands import ( + BaseContext, ApplicationCommand, ApplicationContext, AutocompleteContext, @@ -1222,8 +1223,8 @@ def check(self, func): .. note:: This function can either be a regular function or a coroutine. Similar to a command :func:`.check`, this - takes a single parameter of type :class:`.Context` and can only raise exceptions inherited from - :exc:`.ApplicationCommandError`. + takes a single parameter of type :class:`.BaseContext` and can only raise exceptions inherited from + :exc:`.CommandError`. Example ------- @@ -1282,14 +1283,14 @@ def check_once(self, func): .. note:: - When using this function the :class:`.Context` sent to a group subcommand may only parse the parent command + When using this function the :class:`.BaseContext` sent to a group subcommand may only parse the parent command and not the subcommands due to it being invoked once per :meth:`.Bot.invoke` call. .. note:: This function can either be a regular function or a coroutine. Similar to a command :func:`.check`, - this takes a single parameter of type :class:`.Context` and can only raise exceptions inherited from - :exc:`.ApplicationCommandError`. + this takes a single parameter of type :class:`.BaseContext` and can only raise exceptions inherited from + :exc:`.CommandError`. Example ------- @@ -1318,7 +1319,7 @@ def before_invoke(self, coro): A pre-invoke hook is called directly before the command is called. This makes it a useful function to set up database connections or any type of set up required. - This pre-invoke hook takes a sole parameter, a :class:`.Context`. + This pre-invoke hook takes a sole parameter, a :class:`.BaseContext`. .. note:: @@ -1348,7 +1349,7 @@ def after_invoke(self, coro): A post-invoke hook is called directly after the command is called. This makes it a useful function to clean-up database connections or any type of clean up required. - This post-invoke hook takes a sole parameter, a :class:`.Context`. + This post-invoke hook takes a sole parameter, a :class:`.BaseContext`. .. note:: diff --git a/discord/commands/core.py b/discord/commands/core.py index 6be4afd4ca..24138fc04d 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -33,10 +33,11 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Generator, Generic, TypeVar, Union +from discord import utils from ..channel import _threaded_guild_channel_factory from ..enums import Enum as DiscordEnum from ..enums import MessageType, SlashCommandOptionType, try_enum -from ..errors import ClientException, ValidationError +from ..errors import ClientException, ValidationError, CheckFailure, DisabledCommand from ..member import Member from ..message import Attachment, Message from ..object import Object @@ -142,6 +143,67 @@ def _get_signature_parameters(self): async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: ctx.bot.dispatch("application_command_error", ctx, error) + async def can_run(self, ctx: ApplicationContext) -> bool: + """|coro| + + Checks if the command can be executed by checking all the predicates + inside the :attr:`~ApplicationCommand.checks` attribute. This also checks whether the + command is disabled. + + .. versionchanged:: 1.3 + Checks whether the command is disabled or not + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The ctx of the command currently being invoked. + + Returns + ------- + :class:`bool` + A boolean indicating if the command can be invoked. + + Raises + ------ + :class:`CommandError` + Any command error that was raised during a check call will be propagated + by this function. + """ + + if not self.enabled: + raise DisabledCommand(f"{self.name} command is disabled") + + original = ctx.command + ctx.command = self + + try: + if not await ctx.bot.can_run(ctx): + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) + + # since slash command parents don't really use checks, we can make it + # a feature to have "global" checks for slash commands only + predicates = self.checks + if self.parent is not None: + predicates = self.parent.checks + predicates + + if (cog := self.cog) and ( + local_check := cog._get_overridden_method(cog.cog_check) + ): + ret = await utils.maybe_coroutine(local_check, ctx) + if not ret: + return False + + predicates = self.checks + if not predicates: + # since we have no checks, then we just return True. + return True + + return await utils.async_all(predicate(ctx) for predicate in predicates) + finally: + ctx.command = original + @property def qualified_id(self) -> int: """Retrieves the fully qualified command ID. diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index d13d52b8a1..9b49eb9cd2 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -149,7 +149,7 @@ async def wrapper(*args, **kwargs): finally: if command._max_concurrency is not None: await command._max_concurrency.release(ctx) - await command._call_after_hooks(ctx) + await command.call_after_hooks(ctx) return ret @@ -637,60 +637,6 @@ def after_invoke(self, coro: HookT) -> HookT: self._after_invoke = coro return coro - async def can_run(self, ctx: BaseContext) -> bool: - """|coro| - - Checks if the command can be executed by checking all the predicates - inside the :attr:`~Command.checks` attribute. This also checks whether the - command is disabled. - - .. versionchanged:: 1.3 - Checks whether the command is disabled or not - - Parameters - ---------- - ctx: :class:`.Context` - The ctx of the command currently being invoked. - - Returns - ------- - :class:`bool` - A boolean indicating if the command can be invoked. - - Raises - ------ - :class:`CommandError` - Any command error that was raised during a check call will be propagated - by this function. - """ - if not self.enabled: - raise DisabledCommand(f"{self.name} command is disabled") - - original = ctx.command - ctx.command = self - - try: - if not await ctx.bot.can_run(ctx): - raise CheckFailure( - f"The global check functions for command {self.qualified_name} failed." - ) - - if (cog := self.cog) and ( - local_check := cog._get_overridden_method(cog.cog_check) - ): - ret = await utils.maybe_coroutine(local_check, ctx) - if not ret: - return False - - predicates = self.checks - if not predicates: - # since we have no checks, then we just return True. - return True - - return await utils.async_all(predicate(ctx) for predicate in predicates) - finally: - ctx.command = original - def add_check(self, func: Check) -> None: """Adds a check to the command. @@ -814,7 +760,7 @@ def _prepare_cooldowns(self, ctx: BaseContext): if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - async def _call_before_hooks(self, ctx: BaseContext) -> None: + async def call_before_hooks(self, ctx: BaseContext) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog @@ -839,7 +785,7 @@ async def _call_before_hooks(self, ctx: BaseContext) -> None: if hook is not None: await hook(ctx) - async def _call_after_hooks(self, ctx: BaseContext) -> None: + async def call_after_hooks(self, ctx: BaseContext) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, "__self__", cog) @@ -862,7 +808,7 @@ async def _parse_arguments(self, ctx: BaseContext) -> None: """Parses arguments and attaches them to the context class (Union[:class:`~ext.commands.Context`, :class:`.ApplicationContext`])""" raise NotImplementedError - async def _prepare(self, ctx: BaseContext) -> None: + async def prepare(self, ctx: BaseContext) -> None: ctx.command = self if not await self.can_run(ctx): @@ -882,7 +828,7 @@ async def _prepare(self, ctx: BaseContext) -> None: self._prepare_cooldowns(ctx) await self._parse_arguments(ctx) - await self._call_before_hooks(ctx) + await self.call_before_hooks(ctx) except: if self._max_concurrency is not None: await self._max_concurrency.release(ctx) # type: ignore @@ -896,7 +842,7 @@ async def invoke(self, ctx: BaseContext) -> None: ctx: :class:`.BaseContext` The context to pass into the command. """ - await self._prepare(ctx) + await self.prepare(ctx) # terminate the invoked_subcommand chain. # since we're in a regular command (and not a group) then @@ -926,7 +872,7 @@ async def reinvoke(self, ctx: BaseContext, *, call_hooks: bool = False) -> None: await self._parse_arguments(ctx) if call_hooks: - await self._call_before_hooks(ctx) + await self.call_before_hooks(ctx) ctx.invoked_subcommand = None try: @@ -936,7 +882,7 @@ async def reinvoke(self, ctx: BaseContext, *, call_hooks: bool = False) -> None: raise finally: if call_hooks: - await self._call_after_hooks(ctx) + await self.call_after_hooks(ctx) async def _dispatch_error(self, ctx: BaseContext, error: Exception) -> None: # since I don't want to copy paste code, subclassed Contexts diff --git a/discord/errors.py b/discord/errors.py index f3efc25672..f83c5acab2 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -312,7 +312,6 @@ def __init__(self, e: Exception) -> None: self.original: Exception = e super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") - ApplicationCommandError = CommandError ApplicationCommandInvokeError = CommandInvokeError diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 254a616da5..bac752e19f 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -24,6 +24,8 @@ """ # Cooldowns were moved to discord/commands/cooldowns.py +# This file acts as an alias for now + from ...commands.cooldowns import * __all__ = ( diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index f764135531..3c46b7af52 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -40,8 +40,9 @@ ) import discord - +from discord import utils from ...commands import ( + BaseContext, ApplicationCommand, _BaseCommand, message_command, @@ -549,6 +550,61 @@ def signature(self) -> str: return " ".join(result) + async def can_run(self, ctx: Context) -> bool: + """|coro| + + Checks if the command can be executed by checking all the predicates + inside the :attr:`~Command.checks` attribute. This also checks whether the + command is disabled. + + .. versionchanged:: 1.3 + Checks whether the command is disabled or not + + Parameters + ---------- + ctx: :class:`.Context` + The ctx of the command currently being invoked. + + Returns + ------- + :class:`bool` + A boolean indicating if the command can be invoked. + + Raises + ------ + :class:`CommandError` + Any command error that was raised during a check call will be propagated + by this function. + """ + + if not self.enabled: + raise DisabledCommand(f"{self.name} command is disabled") + + original = ctx.command + ctx.command = self + + try: + if not await ctx.bot.can_run(ctx): + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) + + if (cog := self.cog) and ( + local_check := cog._get_overridden_method(cog.cog_check) + ): + ret = await utils.maybe_coroutine(local_check, ctx) + if not ret: + return False + + predicates = self.checks + if not predicates: + # since we have no checks, then we just return True. + return True + + return await utils.async_all(predicate(ctx) for predicate in predicates) + finally: + ctx.command = original + class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave @@ -871,7 +927,7 @@ async def invoke(self, ctx: Context) -> None: ctx.subcommand_passed = None early_invoke = not self.invoke_without_command if early_invoke: - await self._prepare(ctx) + await self.prepare(ctx) view = ctx.view previous = view.index @@ -905,7 +961,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: await self._parse_arguments(ctx) if call_hooks: - await self._call_before_hooks(ctx) + await self.call_before_hooks(ctx) view = ctx.view previous = view.index @@ -924,7 +980,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: raise finally: if call_hooks: - await self._call_after_hooks(ctx) + await self.call_after_hooks(ctx) ctx.invoked_parents.append(ctx.invoked_with) # type: ignore diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 9f37a8205a..8acab9c04b 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -28,9 +28,9 @@ from typing import TYPE_CHECKING, Any, Callable from discord.errors import ( + CommandError, CheckFailure, ClientException, - CommandError, CommandInvokeError, CommandOnCooldown, DisabledCommand, diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 7edb8a5a00..90899ef6a8 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -199,7 +199,7 @@ def __init__(self, inject, *args, **kwargs): self._original = inject self._injected = inject - async def _prepare(self, ctx): + async def prepare(self, ctx): self._injected = injected = self._original.copy() injected.context = ctx self.callback = injected.command_callback @@ -211,7 +211,7 @@ async def _prepare(self, ctx): else: self.on_error = on_error - await super()._prepare(ctx) + await super().prepare(ctx) async def _parse_arguments(self, ctx): # Make the parser think we don't have a cog so it doesn't From 3645e79888a98fb3209667904679f3c692e85c37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jun 2023 22:32:43 +0000 Subject: [PATCH 52/54] style(pre-commit): auto fixes from pre-commit.com hooks --- CHANGELOG.md | 5 +++-- discord/bot.py | 4 ++-- discord/commands/core.py | 3 ++- discord/errors.py | 1 + discord/ext/commands/core.py | 3 ++- discord/ext/commands/errors.py | 2 +- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a45253532e..57d3131740 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,8 +81,9 @@ These changes are available on the `master` branch, but have not yet been releas ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) - Added the `reinvoke` function to `ApplicationContext`. ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) -- Added the `invoked_with`, `invoked_parents`, `invoked_subcommand`, `subcommand_passed`, - and `command_failed` parameters & attributes to `ApplicationContext`. +- Added the `invoked_with`, `invoked_parents`, `invoked_subcommand`, + `subcommand_passed`, and `command_failed` parameters & attributes to + `ApplicationContext`. ([#1606](https://github.com/Pycord-Development/pycord/pull/1606)) - Added the `source` attribute to `ext.commands.Context` & `ApplicationContext` for a common way to either retrieve the message or the interaction that triggered the diff --git a/discord/bot.py b/discord/bot.py index 1669574886..27db316677 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -34,15 +34,15 @@ import sys import traceback from abc import ABC, abstractmethod -from typing import Any, Callable, Coroutine, Generator, Literal, Mapping, TypeVar, Type +from typing import Any, Callable, Coroutine, Generator, Literal, Mapping, TypeVar from .client import Client from .cog import CogMixin from .commands import ( - BaseContext, ApplicationCommand, ApplicationContext, AutocompleteContext, + BaseContext, MessageCommand, SlashCommand, SlashCommandGroup, diff --git a/discord/commands/core.py b/discord/commands/core.py index 205b1a3a64..e5a8eff5d7 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -35,10 +35,11 @@ from typing import TYPE_CHECKING, Any, Callable, Generator, Generic, TypeVar, Union from discord import utils + from ..channel import _threaded_guild_channel_factory from ..enums import Enum as DiscordEnum from ..enums import MessageType, SlashCommandOptionType, try_enum -from ..errors import ClientException, ValidationError, CheckFailure, DisabledCommand +from ..errors import CheckFailure, ClientException, DisabledCommand, ValidationError from ..member import Member from ..message import Attachment, Message from ..object import Object diff --git a/discord/errors.py b/discord/errors.py index f83c5acab2..f3efc25672 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -312,6 +312,7 @@ def __init__(self, e: Exception) -> None: self.original: Exception = e super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") + ApplicationCommandError = CommandError ApplicationCommandInvokeError = CommandInvokeError diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 3c46b7af52..48b701b466 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -41,9 +41,10 @@ import discord from discord import utils + from ...commands import ( - BaseContext, ApplicationCommand, + BaseContext, _BaseCommand, message_command, slash_command, diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 8acab9c04b..9f37a8205a 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -28,9 +28,9 @@ from typing import TYPE_CHECKING, Any, Callable from discord.errors import ( - CommandError, CheckFailure, ClientException, + CommandError, CommandInvokeError, CommandOnCooldown, DisabledCommand, From e11530573c3e063e787ea0b442330ec97c439b3b Mon Sep 17 00:00:00 2001 From: Middledot Date: Sat, 8 Jul 2023 21:04:33 -0400 Subject: [PATCH 53/54] fix(commands): missing things (@error & cooldowns) --- discord/commands/core.py | 1 + discord/commands/mixins.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/discord/commands/core.py b/discord/commands/core.py index e5a8eff5d7..dc415e2b24 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -76,6 +76,7 @@ from .. import Permissions from .mixins import BaseContext + from .cooldowns import CooldownMapping, MaxConcurrency P = ParamSpec("P") else: diff --git a/discord/commands/mixins.py b/discord/commands/mixins.py index 9b49eb9cd2..095810773d 100644 --- a/discord/commands/mixins.py +++ b/discord/commands/mixins.py @@ -637,6 +637,30 @@ def after_invoke(self, coro: HookT) -> HookT: self._after_invoke = coro return coro + def error(self, coro: ErrorT) -> ErrorT: + """A decorator that registers a coroutine as a local error handler. + A local error handler is an :func:`.on_command_error`/ + :func:`.on_application_command_error` event limited to a single command. + However, the actual :func:`.on_command_error`/:func:`.on_application_command_error` + is still invoked afterwards as the catch-all. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The error handler must be a coroutine.") + + self.on_error: Error = coro + return coro + def add_check(self, func: Check) -> None: """Adds a check to the command. From 0ac4b47ffd6f607812d4f9b39506681a072739c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 9 Jul 2023 01:05:06 +0000 Subject: [PATCH 54/54] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/commands/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index dc415e2b24..bf7709a092 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -75,8 +75,8 @@ from typing_extensions import ParamSpec from .. import Permissions - from .mixins import BaseContext from .cooldowns import CooldownMapping, MaxConcurrency + from .mixins import BaseContext P = ParamSpec("P") else: