Skip to content

Commit

Permalink
feat(commands): don't parse self/ctx parameter annotations of prefix …
Browse files Browse the repository at this point in the history
…command callbacks (#847)
  • Loading branch information
shiftinv authored Nov 16, 2023
1 parent 59b101f commit 2c85e39
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 108 deletions.
1 change: 1 addition & 0 deletions changelog/847.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Skip evaluating annotations of ``self`` (if present) and ``ctx`` parameters in prefix commands. These may now use stringified annotations with types that aren't available at runtime.
40 changes: 3 additions & 37 deletions disnake/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None:
except AttributeError:
globalns = {}

params = get_signature_parameters(function, globalns)
params = get_signature_parameters(function, globalns, skip_standard_params=True)
for param in params.values():
if param.annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
Expand Down Expand Up @@ -607,21 +607,7 @@ def clean_params(self) -> Dict[str, inspect.Parameter]:
Useful for inspecting signature.
"""
result = self.params.copy()
if self.cog is not None:
# first parameter is self
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError("missing 'self' parameter") from None

try:
# first/second parameter is context
del result[next(iter(result))]
except StopIteration:
raise ValueError("missing 'context' parameter") from None

return result
return self.params.copy()

@property
def full_parent_name(self) -> str:
Expand Down Expand Up @@ -693,27 +679,7 @@ async def _parse_arguments(self, ctx: Context) -> None:
kwargs = ctx.kwargs

view = ctx.view
iterator = iter(self.params.items())

if self.cog is not None:
# we have 'self' as the first parameter so just advance
# the iterator and resume parsing
try:
next(iterator)
except StopIteration:
raise disnake.ClientException(
f'Callback for {self.name} command is missing "self" parameter.'
) from None

# next we have the 'ctx' as the next parameter
try:
next(iterator)
except StopIteration:
raise disnake.ClientException(
f'Callback for {self.name} command is missing "ctx" parameter.'
) from None

for name, param in iterator:
for name, param in self.params.items():
ctx.current_parameter = param
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
transformed = await self.transform(ctx, param)
Expand Down
10 changes: 0 additions & 10 deletions disnake/ext/commands/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,6 @@ async def _parse_arguments(self, ctx) -> None:
async def _on_error_cog_implementation(self, dummy, ctx, error) -> None:
await self._injected.on_help_command_error(ctx, error)

@property
def clean_params(self):
result = self.params.copy()
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError("Missing context parameter") from None
else:
return result

def _inject_into_cog(self, cog) -> None:
# Warning: hacky

Expand Down
30 changes: 19 additions & 11 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
from disnake.ext import commands
from disnake.i18n import Localized
from disnake.interactions import ApplicationCommandInteraction
from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine
from disnake.utils import (
get_signature_parameters,
get_signature_return,
maybe_coroutine,
signature_has_self_param,
)

from . import errors
from .converter import CONVERTER_MAPPING
Expand Down Expand Up @@ -771,7 +776,7 @@ def parse_converter_annotation(self, converter: Callable, fallback_annotation: A
# (we need `__call__` here to get the correct global namespace later, since
# classes do not have `__globals__`)
converter_func = converter.__call__
_, parameters = isolate_self(get_signature_parameters(converter_func))
_, parameters = isolate_self(converter_func)

if len(parameters) != 1:
raise TypeError(
Expand Down Expand Up @@ -879,9 +884,16 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa


def isolate_self(
parameters: Dict[str, inspect.Parameter],
function: Callable,
parameters: Optional[Dict[str, inspect.Parameter]] = None,
) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]:
"""Create parameters without self and the first interaction"""
"""Create parameters without self and the first interaction.
Optionally accepts a `{str: inspect.Parameter}` dict as an optimization,
calls `get_signature_parameters(function)` if not provided.
"""
if parameters is None:
parameters = get_signature_parameters(function)
if not parameters:
return (None, None), {}

Expand All @@ -891,7 +903,7 @@ def isolate_self(
cog_param: Optional[inspect.Parameter] = None
inter_param: Optional[inspect.Parameter] = None

if parametersl[0].name == "self":
if signature_has_self_param(function):
cog_param = parameters.pop(parametersl[0].name)
parametersl.pop(0)
if parametersl:
Expand Down Expand Up @@ -941,15 +953,11 @@ def collect_params(
) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]:
"""Collect all parameters in a function.
Optionally accepts a `{str: inspect.Parameter}` dict as an optimization,
calls `get_signature_parameters(function)` if not provided.
Optionally accepts a `{str: inspect.Parameter}` dict as an optimization.
Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`)
"""
if parameters is None:
parameters = get_signature_parameters(function)

(cog_param, inter_param), parameters = isolate_self(parameters)
(cog_param, inter_param), parameters = isolate_self(function, parameters)

doc = disnake.utils.parse_docstring(function)["params"]

Expand Down
70 changes: 68 additions & 2 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pkgutil
import re
import sys
import types
import unicodedata
import warnings
from base64 import b64encode
Expand Down Expand Up @@ -1227,7 +1228,10 @@ def _get_function_globals(function: Callable[..., Any]) -> Dict[str, Any]:


def get_signature_parameters(
function: Callable[..., Any], globalns: Optional[Dict[str, Any]] = None
function: Callable[..., Any],
globalns: Optional[Dict[str, Any]] = None,
*,
skip_standard_params: bool = False,
) -> Dict[str, inspect.Parameter]:
# if no globalns provided, unwrap (where needed) and get global namespace from there
if globalns is None:
Expand All @@ -1237,9 +1241,23 @@ def get_signature_parameters(
cache: Dict[str, Any] = {}

signature = inspect.signature(function)
iterator = iter(signature.parameters.items())

if skip_standard_params:
# skip `self` (if present) and `ctx` parameters,
# since their annotations are irrelevant
skip = 2 if signature_has_self_param(function) else 1

for _ in range(skip):
try:
next(iterator)
except StopIteration:
raise ValueError(
f"Expected command callback to have at least {skip} parameter(s)"
) from None

# eval all parameter annotations
for name, parameter in signature.parameters.items():
for name, parameter in iterator:
annotation = parameter.annotation
if annotation is _inspect_empty:
params[name] = parameter
Expand Down Expand Up @@ -1270,6 +1288,54 @@ def get_signature_return(function: Callable[..., Any]) -> Any:
return ret


def signature_has_self_param(function: Callable[..., Any]) -> bool:
# If a function was defined in a class and is not bound (i.e. is not types.MethodType),
# it should have a `self` parameter.
# Bound methods technically also have a `self` parameter, but this is
# used in conjunction with `inspect.signature`, which drops that parameter.
#
# There isn't really any way to reliably detect whether a function
# was defined in a class, other than `__qualname__`, thanks to PEP 3155.
# As noted in the PEP, this doesn't work with rebinding, but that should be a pretty rare edge case.
#
#
# There are a few possible situations here - for the purposes of this method,
# we want to detect the first case only:
# (1) The preceding component for *methods in classes* will be the class name, resulting in `Clazz.func`.
# (2) For *unbound* functions (not methods), `__qualname__ == __name__`.
# (3) Bound methods (i.e. types.MethodType) don't have a `self` parameter in the context of this function (see first paragraph).
# (we currently don't expect to handle bound methods anywhere, except the default help command implementation).
# (4) A somewhat special case are lambdas defined in a class namespace (but not inside a method), which use `Clazz.<lambda>` and shouldn't match (1).
# (lambdas at class level are a bit funky; we currently only expect them in the `Param(converter=)` kwarg, which doesn't take a `self` parameter).
# (5) Similarly, *nested functions* use `containing_func.<locals>.func` and shouldn't have a `self` parameter.
#
# Working solely based on this string is certainly not ideal,
# but the compiler does a bunch of processing just for that attribute,
# and there's really no other way to retrieve this information through other means later.
# (3.10: https://github.com/python/cpython/blob/e07086db03d2dc1cd2e2a24f6c9c0ddd422b4cf0/Python/compile.c#L744)
#
# Not reliable for classmethod/staticmethod.

qname = function.__qualname__
if qname == function.__name__:
# (2)
return False

if isinstance(function, types.MethodType):
# (3)
return False

# "a.b.c.d" => "a.b.c", "d"
parent, basename = qname.rsplit(".", 1)

if basename == "<lambda>":
# (4)
return False

# (5)
return not parent.endswith(".<locals>")


TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]


Expand Down
110 changes: 62 additions & 48 deletions tests/ext/commands/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import disnake
from disnake import Member, Role, User
from disnake.ext import commands
from disnake.ext.commands import params

OptionType = disnake.OptionType

Expand Down Expand Up @@ -67,53 +66,6 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None:
with pytest.raises(commands.errors.MemberNotFound):
await info.verify_type(mock.Mock(), arg_mock)

def test_isolate_self(self) -> None:
def func(a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_inter(self) -> None:
def func(i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_cog_inter(self) -> None:
def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is not None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_generic(self) -> None:
def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_union(self) -> None:
def func(
i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int
) -> None:
...

(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})


# this uses `Range` for testing `_BaseRange`, `String` should work equally
class TestBaseRange:
Expand Down Expand Up @@ -260,3 +212,65 @@ def test_optional(self, annotation_str) -> None:
assert info.min_value == 1
assert info.max_value == 2
assert info.type == int


class TestIsolateSelf:
def test_function_simple(self) -> None:
def func(a: int) -> None:
...

(cog, inter), params = commands.params.isolate_self(func)
assert cog is None
assert inter is None
assert params.keys() == {"a"}

def test_function_inter(self) -> None:
def func(inter: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), params = commands.params.isolate_self(func)
assert cog is None # should not be set
assert inter is not None
assert params.keys() == {"a"}

def test_unbound_method(self) -> None:
class Cog(commands.Cog):
def func(self, inter: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), params = commands.params.isolate_self(Cog.func)
assert cog is not None # *should* be set here
assert inter is not None
assert params.keys() == {"a"}

# I don't think the param parsing logic ever handles bound methods, but testing for regressions anyway
def test_bound_method(self) -> None:
class Cog(commands.Cog):
def func(self, inter: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), params = commands.params.isolate_self(Cog().func)
assert cog is None # should not be set here, since method is already bound
assert inter is not None
assert params.keys() == {"a"}

def test_generic(self) -> None:
def func(inter: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None:
...

(cog, inter), params = commands.params.isolate_self(func)
assert cog is None
assert inter is not None
assert params.keys() == {"a"}

def test_inter_union(self) -> None:
def func(
inter: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]],
a: int,
) -> None:
...

(cog, inter), params = commands.params.isolate_self(func)
assert cog is None
assert inter is not None
assert params.keys() == {"a"}
Loading

0 comments on commit 2c85e39

Please sign in to comment.