From 2c85e39f415e9d96867dc317592582b685888ed9 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 16 Nov 2023 15:37:08 +0100 Subject: [PATCH] feat(commands): don't parse self/ctx parameter annotations of prefix command callbacks (#847) --- changelog/847.feature.rst | 1 + disnake/ext/commands/core.py | 40 +---------- disnake/ext/commands/help.py | 10 --- disnake/ext/commands/params.py | 30 +++++--- disnake/utils.py | 70 ++++++++++++++++++- tests/ext/commands/test_params.py | 110 +++++++++++++++++------------- tests/test_utils.py | 82 ++++++++++++++++++++++ 7 files changed, 235 insertions(+), 108 deletions(-) create mode 100644 changelog/847.feature.rst diff --git a/changelog/847.feature.rst b/changelog/847.feature.rst new file mode 100644 index 0000000000..7418ed0783 --- /dev/null +++ b/changelog/847.feature.rst @@ -0,0 +1 @@ +|commands| Skip evaluating annotations of ``self`` (if present) and ``ctx`` parameters in prefix commands. These may now use stringified annotations with types that aren't available at runtime. diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index 2bb108e966..fda34b5a95 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -381,7 +381,7 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None: except AttributeError: globalns = {} - params = get_signature_parameters(function, globalns) + params = get_signature_parameters(function, globalns, skip_standard_params=True) for param in params.values(): if param.annotation is Greedy: raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") @@ -607,21 +607,7 @@ def clean_params(self) -> Dict[str, inspect.Parameter]: Useful for inspecting signature. """ - result = self.params.copy() - if self.cog is not None: - # first parameter is self - try: - del result[next(iter(result))] - except StopIteration: - raise ValueError("missing 'self' parameter") from None - - try: - # first/second parameter is context - del result[next(iter(result))] - except StopIteration: - raise ValueError("missing 'context' parameter") from None - - return result + return self.params.copy() @property def full_parent_name(self) -> str: @@ -693,27 +679,7 @@ async def _parse_arguments(self, ctx: Context) -> None: kwargs = ctx.kwargs view = ctx.view - iterator = iter(self.params.items()) - - if self.cog is not None: - # we have 'self' as the first parameter so just advance - # the iterator and resume parsing - try: - next(iterator) - except StopIteration: - raise disnake.ClientException( - f'Callback for {self.name} command is missing "self" parameter.' - ) from None - - # next we have the 'ctx' as the next parameter - try: - next(iterator) - except StopIteration: - raise disnake.ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) from None - - for name, param in iterator: + for name, param in self.params.items(): ctx.current_parameter = param if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): transformed = await self.transform(ctx, param) diff --git a/disnake/ext/commands/help.py b/disnake/ext/commands/help.py index 5841a8ba11..483d4f4bd2 100644 --- a/disnake/ext/commands/help.py +++ b/disnake/ext/commands/help.py @@ -202,16 +202,6 @@ async def _parse_arguments(self, ctx) -> None: async def _on_error_cog_implementation(self, dummy, ctx, error) -> None: await self._injected.on_help_command_error(ctx, error) - @property - def clean_params(self): - result = self.params.copy() - try: - del result[next(iter(result))] - except StopIteration: - raise ValueError("Missing context parameter") from None - else: - return result - def _inject_into_cog(self, cog) -> None: # Warning: hacky diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 5aae2de611..95679ed802 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -42,7 +42,12 @@ from disnake.ext import commands from disnake.i18n import Localized from disnake.interactions import ApplicationCommandInteraction -from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine +from disnake.utils import ( + get_signature_parameters, + get_signature_return, + maybe_coroutine, + signature_has_self_param, +) from . import errors from .converter import CONVERTER_MAPPING @@ -771,7 +776,7 @@ def parse_converter_annotation(self, converter: Callable, fallback_annotation: A # (we need `__call__` here to get the correct global namespace later, since # classes do not have `__globals__`) converter_func = converter.__call__ - _, parameters = isolate_self(get_signature_parameters(converter_func)) + _, parameters = isolate_self(converter_func) if len(parameters) != 1: raise TypeError( @@ -879,9 +884,16 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa def isolate_self( - parameters: Dict[str, inspect.Parameter], + function: Callable, + parameters: Optional[Dict[str, inspect.Parameter]] = None, ) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]: - """Create parameters without self and the first interaction""" + """Create parameters without self and the first interaction. + + Optionally accepts a `{str: inspect.Parameter}` dict as an optimization, + calls `get_signature_parameters(function)` if not provided. + """ + if parameters is None: + parameters = get_signature_parameters(function) if not parameters: return (None, None), {} @@ -891,7 +903,7 @@ def isolate_self( cog_param: Optional[inspect.Parameter] = None inter_param: Optional[inspect.Parameter] = None - if parametersl[0].name == "self": + if signature_has_self_param(function): cog_param = parameters.pop(parametersl[0].name) parametersl.pop(0) if parametersl: @@ -941,15 +953,11 @@ def collect_params( ) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]: """Collect all parameters in a function. - Optionally accepts a `{str: inspect.Parameter}` dict as an optimization, - calls `get_signature_parameters(function)` if not provided. + Optionally accepts a `{str: inspect.Parameter}` dict as an optimization. Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`) """ - if parameters is None: - parameters = get_signature_parameters(function) - - (cog_param, inter_param), parameters = isolate_self(parameters) + (cog_param, inter_param), parameters = isolate_self(function, parameters) doc = disnake.utils.parse_docstring(function)["params"] diff --git a/disnake/utils.py b/disnake/utils.py index d40cd4e8fe..a74d50ab94 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -12,6 +12,7 @@ import pkgutil import re import sys +import types import unicodedata import warnings from base64 import b64encode @@ -1227,7 +1228,10 @@ def _get_function_globals(function: Callable[..., Any]) -> Dict[str, Any]: def get_signature_parameters( - function: Callable[..., Any], globalns: Optional[Dict[str, Any]] = None + function: Callable[..., Any], + globalns: Optional[Dict[str, Any]] = None, + *, + skip_standard_params: bool = False, ) -> Dict[str, inspect.Parameter]: # if no globalns provided, unwrap (where needed) and get global namespace from there if globalns is None: @@ -1237,9 +1241,23 @@ def get_signature_parameters( cache: Dict[str, Any] = {} signature = inspect.signature(function) + iterator = iter(signature.parameters.items()) + + if skip_standard_params: + # skip `self` (if present) and `ctx` parameters, + # since their annotations are irrelevant + skip = 2 if signature_has_self_param(function) else 1 + + for _ in range(skip): + try: + next(iterator) + except StopIteration: + raise ValueError( + f"Expected command callback to have at least {skip} parameter(s)" + ) from None # eval all parameter annotations - for name, parameter in signature.parameters.items(): + for name, parameter in iterator: annotation = parameter.annotation if annotation is _inspect_empty: params[name] = parameter @@ -1270,6 +1288,54 @@ def get_signature_return(function: Callable[..., Any]) -> Any: return ret +def signature_has_self_param(function: Callable[..., Any]) -> bool: + # If a function was defined in a class and is not bound (i.e. is not types.MethodType), + # it should have a `self` parameter. + # Bound methods technically also have a `self` parameter, but this is + # used in conjunction with `inspect.signature`, which drops that parameter. + # + # There isn't really any way to reliably detect whether a function + # was defined in a class, other than `__qualname__`, thanks to PEP 3155. + # As noted in the PEP, this doesn't work with rebinding, but that should be a pretty rare edge case. + # + # + # There are a few possible situations here - for the purposes of this method, + # we want to detect the first case only: + # (1) The preceding component for *methods in classes* will be the class name, resulting in `Clazz.func`. + # (2) For *unbound* functions (not methods), `__qualname__ == __name__`. + # (3) Bound methods (i.e. types.MethodType) don't have a `self` parameter in the context of this function (see first paragraph). + # (we currently don't expect to handle bound methods anywhere, except the default help command implementation). + # (4) A somewhat special case are lambdas defined in a class namespace (but not inside a method), which use `Clazz.` and shouldn't match (1). + # (lambdas at class level are a bit funky; we currently only expect them in the `Param(converter=)` kwarg, which doesn't take a `self` parameter). + # (5) Similarly, *nested functions* use `containing_func..func` and shouldn't have a `self` parameter. + # + # Working solely based on this string is certainly not ideal, + # but the compiler does a bunch of processing just for that attribute, + # and there's really no other way to retrieve this information through other means later. + # (3.10: https://github.com/python/cpython/blob/e07086db03d2dc1cd2e2a24f6c9c0ddd422b4cf0/Python/compile.c#L744) + # + # Not reliable for classmethod/staticmethod. + + qname = function.__qualname__ + if qname == function.__name__: + # (2) + return False + + if isinstance(function, types.MethodType): + # (3) + return False + + # "a.b.c.d" => "a.b.c", "d" + parent, basename = qname.rsplit(".", 1) + + if basename == "": + # (4) + return False + + # (5) + return not parent.endswith(".") + + TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] diff --git a/tests/ext/commands/test_params.py b/tests/ext/commands/test_params.py index 61c812c8f0..96f2c08c32 100644 --- a/tests/ext/commands/test_params.py +++ b/tests/ext/commands/test_params.py @@ -10,7 +10,6 @@ import disnake from disnake import Member, Role, User from disnake.ext import commands -from disnake.ext.commands import params OptionType = disnake.OptionType @@ -67,53 +66,6 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None: with pytest.raises(commands.errors.MemberNotFound): await info.verify_type(mock.Mock(), arg_mock) - def test_isolate_self(self) -> None: - def func(a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_inter(self) -> None: - def func(i: disnake.ApplicationCommandInteraction, a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_cog_inter(self) -> None: - def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is not None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_generic(self) -> None: - def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_union(self) -> None: - def func( - i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int - ) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - # this uses `Range` for testing `_BaseRange`, `String` should work equally class TestBaseRange: @@ -260,3 +212,65 @@ def test_optional(self, annotation_str) -> None: assert info.min_value == 1 assert info.max_value == 2 assert info.type == int + + +class TestIsolateSelf: + def test_function_simple(self) -> None: + def func(a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None + assert inter is None + assert params.keys() == {"a"} + + def test_function_inter(self) -> None: + def func(inter: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None # should not be set + assert inter is not None + assert params.keys() == {"a"} + + def test_unbound_method(self) -> None: + class Cog(commands.Cog): + def func(self, inter: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(Cog.func) + assert cog is not None # *should* be set here + assert inter is not None + assert params.keys() == {"a"} + + # I don't think the param parsing logic ever handles bound methods, but testing for regressions anyway + def test_bound_method(self) -> None: + class Cog(commands.Cog): + def func(self, inter: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(Cog().func) + assert cog is None # should not be set here, since method is already bound + assert inter is not None + assert params.keys() == {"a"} + + def test_generic(self) -> None: + def func(inter: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None + assert inter is not None + assert params.keys() == {"a"} + + def test_inter_union(self) -> None: + def func( + inter: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], + a: int, + ) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None + assert inter is not None + assert params.keys() == {"a"} diff --git a/tests/test_utils.py b/tests/test_utils.py index a8f52e6b1f..d767264a95 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import asyncio import datetime +import functools import inspect import os import sys @@ -880,3 +881,84 @@ def test_as_valid_locale(locale, expected) -> None: ) def test_humanize_list(values, expected) -> None: assert utils.humanize_list(values, "plus") == expected + + +# used for `test_signature_has_self_param` +def _toplevel(): + def inner() -> None: + ... + + return inner + + +def decorator(f): + @functools.wraps(f) + def wrap(self, *args, **kwargs): + return f(self, *args, **kwargs) + + return wrap + + +# used for `test_signature_has_self_param` +class _Clazz: + def func(self): + def inner() -> None: + ... + + return inner + + @classmethod + def cmethod(cls) -> None: + ... + + @staticmethod + def smethod() -> None: + ... + + class Nested: + def func(self): + def inner() -> None: + ... + + return inner + + rebind = _toplevel + + @decorator + def decorated(self) -> None: + ... + + _lambda = lambda: None + + +@pytest.mark.parametrize( + ("function", "expected"), + [ + # top-level function + (_toplevel, False), + # methods in class + (_Clazz.func, True), + (_Clazz().func, False), + # unfortunately doesn't work + (_Clazz.rebind, False), + (_Clazz().rebind, False), + # classmethod/staticmethod isn't supported, but checked to ensure consistency + (_Clazz.cmethod, False), + (_Clazz.smethod, True), + # nested class methods + (_Clazz.Nested.func, True), + (_Clazz.Nested().func, False), + # inner methods + (_toplevel(), False), + (_Clazz().func(), False), + (_Clazz.Nested().func(), False), + # decorated method + (_Clazz.decorated, True), + (_Clazz().decorated, False), + # lambda (class-level) + (_Clazz._lambda, False), + (_Clazz()._lambda, False), + ], +) +def test_signature_has_self_param(function, expected) -> None: + assert utils.signature_has_self_param(function) == expected