Skip to content

Commit

Permalink
refactor: unify slash/prefix command signature evaluation (#1116)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv authored Oct 26, 2023
1 parent b23786b commit f2e5886
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 97 deletions.
1 change: 1 addition & 0 deletions changelog/1116.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Rewrite slash command signature evaluation to use the same mechanism as prefix command signatures. This should not have an impact on user code, but streamlines future changes.
49 changes: 11 additions & 38 deletions disnake/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
)

import disnake
from disnake.utils import _generated, _overload_with_permissions
from disnake.utils import (
_generated,
_overload_with_permissions,
get_signature_parameters,
unwrap_function,
)

from ._types import _BaseCommand
from .cog import Cog
Expand Down Expand Up @@ -114,42 +119,6 @@
P = TypeVar("P")


def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function


def get_signature_parameters(
function: Callable[..., Any], globalns: Dict[str, Any]
) -> Dict[str, inspect.Parameter]:
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
eval_annotation = disnake.utils.evaluate_annotation
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
params[name] = parameter
continue
if annotation is None:
params[name] = parameter.replace(annotation=type(None))
continue

annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")

params[name] = parameter.replace(annotation=annotation)

return params


def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
Expand Down Expand Up @@ -410,7 +379,11 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None:
except AttributeError:
globalns = {}

self.params = get_signature_parameters(function, globalns)
params = get_signature_parameters(function, globalns)
for param in params.values():
if param.annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
self.params = params

def add_check(self, func: Check) -> None:
"""Adds a check to the command.
Expand Down
83 changes: 29 additions & 54 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import itertools
import math
import sys
import types
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
Expand All @@ -32,7 +33,6 @@
TypeVar,
Union,
get_origin,
get_type_hints,
)

import disnake
Expand All @@ -42,7 +42,7 @@
from disnake.ext import commands
from disnake.i18n import Localized
from disnake.interactions import ApplicationCommandInteraction
from disnake.utils import maybe_coroutine
from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine

from . import errors
from .converter import CONVERTER_MAPPING
Expand Down Expand Up @@ -143,37 +143,6 @@ def remove_optionals(annotation: Any) -> Any:
return annotation


def signature(func: Callable) -> inspect.Signature:
"""Get the signature with evaluated annotations wherever possible
This is equivalent to `signature(..., eval_str=True)` in python 3.10
"""
if sys.version_info >= (3, 10):
return inspect.signature(func, eval_str=True)

if inspect.isfunction(func) or inspect.ismethod(func):
typehints = get_type_hints(func)
else:
typehints = get_type_hints(func.__call__)

signature = inspect.signature(func)
parameters = []

for name, param in signature.parameters.items():
if isinstance(param.annotation, str):
param = param.replace(annotation=typehints.get(name, inspect.Parameter.empty))
if param.annotation is type(None):
param = param.replace(annotation=None)

parameters.append(param)

return_annotation = typehints.get("return", inspect.Parameter.empty)
if return_annotation is type(None):
return_annotation = None

return signature.replace(parameters=parameters, return_annotation=return_annotation)


def _xt_to_xe(xe: Optional[float], xt: Optional[float], direction: float = 1) -> Optional[float]:
"""Function for combining xt and xe
Expand Down Expand Up @@ -795,7 +764,14 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo
return True

def parse_converter_annotation(self, converter: Callable, fallback_annotation: Any) -> None:
_, parameters = isolate_self(signature(converter))
if isinstance(converter, (types.FunctionType, types.MethodType)):
converter_func = converter
else:
# if converter isn't a function/method, assume it's a callable object/type
# (we need `__call__` here to get the correct global namespace later, since
# classes do not have `__globals__`)
converter_func = converter.__call__
_, parameters = isolate_self(get_signature_parameters(converter_func))

if len(parameters) != 1:
raise TypeError(
Expand Down Expand Up @@ -858,9 +834,9 @@ def to_option(self) -> Option:
def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwargs: Any) -> T:
"""Calls a function without providing any extra unexpected arguments"""
MISSING: Any = object()
sig = signature(function)
parameters = get_signature_parameters(function)

kinds = {p.kind for p in sig.parameters.values()}
kinds = {p.kind for p in parameters.values()}
arb = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
if arb.issubset(kinds):
raise TypeError(
Expand All @@ -874,7 +850,7 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa

for index, parameter, posarg in itertools.zip_longest(
itertools.count(),
sig.parameters.values(),
parameters.values(),
possible_args,
fillvalue=MISSING,
):
Expand Down Expand Up @@ -903,15 +879,15 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa


def isolate_self(
sig: inspect.Signature,
parameters: Dict[str, inspect.Parameter],
) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]:
"""Create parameters without self and the first interaction"""
parameters = dict(sig.parameters)
parametersl = list(sig.parameters.values())

if not parameters:
return (None, None), {}

parameters = dict(parameters) # shallow copy
parametersl = list(parameters.values())

cog_param: Optional[inspect.Parameter] = None
inter_param: Optional[inspect.Parameter] = None

Expand Down Expand Up @@ -961,19 +937,19 @@ def classify_autocompleter(autocompleter: AnyAutocompleter) -> None:

def collect_params(
function: Callable,
sig: Optional[inspect.Signature] = None,
parameters: Optional[Dict[str, inspect.Parameter]] = None,
) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]:
"""Collect all parameters in a function.
Optionally accepts an `inspect.Signature` object (as an optimization),
calls `signature(function)` if not provided.
Optionally accepts a `{str: inspect.Parameter}` dict as an optimization,
calls `get_signature_parameters(function)` if not provided.
Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`)
"""
if sig is None:
sig = signature(function)
if parameters is None:
parameters = get_signature_parameters(function)

(cog_param, inter_param), parameters = isolate_self(sig)
(cog_param, inter_param), parameters = isolate_self(parameters)

doc = disnake.utils.parse_docstring(function)["params"]

Expand Down Expand Up @@ -1097,10 +1073,10 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
Returns the created options
"""
sig = signature(command.callback)
# pass `sig` down to avoid having to call `signature(func)` another time,
parameters = get_signature_parameters(command.callback)
# pass `parameters` down to avoid having to call `get_signature_parameters(func)` another time,
# which may cause side effects with deferred annotations and warnings
_, inter_param, params, injections = collect_params(command.callback, sig)
_, inter_param, params, injections = collect_params(command.callback, parameters)

if inter_param is None:
raise TypeError(f"Couldn't find an interaction parameter in {command.callback}")
Expand All @@ -1125,7 +1101,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
if param.autocomplete:
command.autocompleters[param.name] = param.autocomplete

if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction):
if issubclass_(parameters[inter_param].annotation, disnake.GuildCommandInteraction):
command._guild_only = True

return [param.to_option() for param in params]
Expand Down Expand Up @@ -1407,12 +1383,11 @@ def register_injection(
:class:`Injection`
The injection being registered.
"""
sig = signature(function)
tp = sig.return_annotation
tp = get_signature_return(function)

if tp is inspect.Parameter.empty:
raise TypeError("Injection must have a return annotation")
if tp in ParamInfo.TYPES:
raise TypeError("Injection cannot overwrite builtin types")

return Injection.register(function, sig.return_annotation, autocompleters=autocompleters)
return Injection.register(function, tp, autocompleters=autocompleters)
67 changes: 67 additions & 0 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import datetime
import functools
import inspect
import json
import os
import pkgutil
Expand Down Expand Up @@ -1203,6 +1204,72 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)


def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function


def _get_function_globals(function: Callable[..., Any]) -> Dict[str, Any]:
unwrap = unwrap_function(function)
try:
return unwrap.__globals__
except AttributeError:
return {}


_inspect_empty = inspect.Parameter.empty


def get_signature_parameters(
function: Callable[..., Any], globalns: Optional[Dict[str, Any]] = None
) -> Dict[str, inspect.Parameter]:
# if no globalns provided, unwrap (where needed) and get global namespace from there
if globalns is None:
globalns = _get_function_globals(function)

params: Dict[str, inspect.Parameter] = {}
cache: Dict[str, Any] = {}

signature = inspect.signature(function)

# eval all parameter annotations
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is _inspect_empty:
params[name] = parameter
continue

if annotation is None:
annotation = type(None)
else:
annotation = evaluate_annotation(annotation, globalns, globalns, cache)

params[name] = parameter.replace(annotation=annotation)

return params


def get_signature_return(function: Callable[..., Any]) -> Any:
signature = inspect.signature(function)

# same as parameters above, but for the return annotation
ret = signature.return_annotation
if ret is not _inspect_empty:
if ret is None:
ret = type(None)
else:
globalns = _get_function_globals(function)
ret = evaluate_annotation(ret, globalns, globalns, {})

return ret


TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]


Expand Down
10 changes: 5 additions & 5 deletions tests/ext/commands/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_isolate_self(self) -> None:
def func(a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is None
assert parameters == ({"a": mock.ANY})
Expand All @@ -80,7 +80,7 @@ def test_isolate_self_inter(self) -> None:
def func(i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand All @@ -89,7 +89,7 @@ def test_isolate_self_cog_inter(self) -> None:
def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is not None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand All @@ -98,7 +98,7 @@ def test_isolate_self_generic(self) -> None:
def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand All @@ -109,7 +109,7 @@ def func(
) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand Down

0 comments on commit f2e5886

Please sign in to comment.