diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index 2d7ff5497e..669dd60e04 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -27,7 +27,12 @@ ) import disnake -from disnake.utils import _generated, _overload_with_permissions +from disnake.utils import ( + _generated, + _overload_with_permissions, + get_signature_parameters, + unwrap_function, +) from ._types import _BaseCommand from .cog import Cog @@ -114,42 +119,6 @@ P = TypeVar("P") -def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: - partial = functools.partial - while True: - if hasattr(function, "__wrapped__"): - function = function.__wrapped__ - elif isinstance(function, partial): - function = function.func - else: - return function - - -def get_signature_parameters( - function: Callable[..., Any], globalns: Dict[str, Any] -) -> Dict[str, inspect.Parameter]: - signature = inspect.signature(function) - params = {} - cache: Dict[str, Any] = {} - eval_annotation = disnake.utils.evaluate_annotation - for name, parameter in signature.parameters.items(): - annotation = parameter.annotation - if annotation is parameter.empty: - params[name] = parameter - continue - if annotation is None: - params[name] = parameter.replace(annotation=type(None)) - continue - - annotation = eval_annotation(annotation, globalns, globalns, cache) - if annotation is Greedy: - raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") - - params[name] = parameter.replace(annotation=annotation) - - return params - - def wrap_callback(coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -410,7 +379,11 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None: except AttributeError: globalns = {} - self.params = get_signature_parameters(function, globalns) + params = get_signature_parameters(function, globalns) + for param in params.values(): + if param.annotation is Greedy: + raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") + self.params = params def add_check(self, func: Check) -> None: """Adds a check to the command. diff --git a/disnake/utils.py b/disnake/utils.py index 15b2f53ee0..cd10aa7748 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -6,6 +6,7 @@ import asyncio import datetime import functools +import inspect import json import os import pkgutil @@ -1200,6 +1201,38 @@ def resolve_annotation( return evaluate_annotation(annotation, globalns, locals, cache) +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, "__wrapped__"): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + + +def get_signature_parameters( + function: Callable[..., Any], globalns: Dict[str, Any] +) -> Dict[str, inspect.Parameter]: + signature = inspect.signature(function) + params = {} + cache: Dict[str, Any] = {} + for name, parameter in signature.parameters.items(): + annotation = parameter.annotation + if annotation is parameter.empty: + params[name] = parameter + continue + if annotation is None: + params[name] = parameter.replace(annotation=type(None)) + continue + + annotation = evaluate_annotation(annotation, globalns, globalns, cache) + params[name] = parameter.replace(annotation=annotation) + + return params + + TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]