Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify slash/prefix command signature evaluation #1116

Merged
merged 10 commits into from
Oct 26, 2023
1 change: 1 addition & 0 deletions changelog/1116.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Rewrite slash command signature evaluation to use the same mechanism as prefix command signatures. This should not have an impact on user code, but streamlines future changes.
49 changes: 11 additions & 38 deletions disnake/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
)

import disnake
from disnake.utils import _generated, _overload_with_permissions
from disnake.utils import (
_generated,
_overload_with_permissions,
get_signature_parameters,
unwrap_function,
)

from ._types import _BaseCommand
from .cog import Cog
Expand Down Expand Up @@ -114,42 +119,6 @@
P = TypeVar("P")


def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function


def get_signature_parameters(
function: Callable[..., Any], globalns: Dict[str, Any]
) -> Dict[str, inspect.Parameter]:
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
eval_annotation = disnake.utils.evaluate_annotation
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
params[name] = parameter
continue
if annotation is None:
params[name] = parameter.replace(annotation=type(None))
continue

annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")

params[name] = parameter.replace(annotation=annotation)

return params


def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
Expand Down Expand Up @@ -410,7 +379,11 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None:
except AttributeError:
globalns = {}

self.params = get_signature_parameters(function, globalns)
params = get_signature_parameters(function, globalns)
for param in params.values():
if param.annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
self.params = params

def add_check(self, func: Check) -> None:
"""Adds a check to the command.
Expand Down
83 changes: 29 additions & 54 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import itertools
import math
import sys
import types
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
Expand All @@ -32,7 +33,6 @@
TypeVar,
Union,
get_origin,
get_type_hints,
)

import disnake
Expand All @@ -42,7 +42,7 @@
from disnake.ext import commands
from disnake.i18n import Localized
from disnake.interactions import ApplicationCommandInteraction
from disnake.utils import maybe_coroutine
from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine

from . import errors
from .converter import CONVERTER_MAPPING
Expand Down Expand Up @@ -143,37 +143,6 @@ def remove_optionals(annotation: Any) -> Any:
return annotation


def signature(func: Callable) -> inspect.Signature:
"""Get the signature with evaluated annotations wherever possible

This is equivalent to `signature(..., eval_str=True)` in python 3.10
"""
if sys.version_info >= (3, 10):
return inspect.signature(func, eval_str=True)

if inspect.isfunction(func) or inspect.ismethod(func):
typehints = get_type_hints(func)
else:
typehints = get_type_hints(func.__call__)

signature = inspect.signature(func)
parameters = []

for name, param in signature.parameters.items():
if isinstance(param.annotation, str):
param = param.replace(annotation=typehints.get(name, inspect.Parameter.empty))
if param.annotation is type(None):
param = param.replace(annotation=None)

parameters.append(param)

return_annotation = typehints.get("return", inspect.Parameter.empty)
if return_annotation is type(None):
return_annotation = None

return signature.replace(parameters=parameters, return_annotation=return_annotation)


def _xt_to_xe(xe: Optional[float], xt: Optional[float], direction: float = 1) -> Optional[float]:
"""Function for combining xt and xe

Expand Down Expand Up @@ -795,7 +764,14 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo
return True

def parse_converter_annotation(self, converter: Callable, fallback_annotation: Any) -> None:
_, parameters = isolate_self(signature(converter))
if isinstance(converter, (types.FunctionType, types.MethodType)):
converter_func = converter
else:
# if converter isn't a function/method, assume it's a callable object/type
# (we need `__call__` here to get the correct global namespace later, since
# classes do not have `__globals__`)
converter_func = converter.__call__
EQUENOS marked this conversation as resolved.
Show resolved Hide resolved
_, parameters = isolate_self(get_signature_parameters(converter_func))

if len(parameters) != 1:
raise TypeError(
Expand Down Expand Up @@ -858,9 +834,9 @@ def to_option(self) -> Option:
def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwargs: Any) -> T:
"""Calls a function without providing any extra unexpected arguments"""
MISSING: Any = object()
sig = signature(function)
parameters = get_signature_parameters(function)

kinds = {p.kind for p in sig.parameters.values()}
kinds = {p.kind for p in parameters.values()}
arb = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
if arb.issubset(kinds):
raise TypeError(
Expand All @@ -874,7 +850,7 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa

for index, parameter, posarg in itertools.zip_longest(
itertools.count(),
sig.parameters.values(),
parameters.values(),
possible_args,
fillvalue=MISSING,
):
Expand Down Expand Up @@ -903,15 +879,15 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa


def isolate_self(
sig: inspect.Signature,
parameters: Dict[str, inspect.Parameter],
) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]:
"""Create parameters without self and the first interaction"""
parameters = dict(sig.parameters)
parametersl = list(sig.parameters.values())

if not parameters:
return (None, None), {}

parameters = dict(parameters) # shallow copy
parametersl = list(parameters.values())

cog_param: Optional[inspect.Parameter] = None
inter_param: Optional[inspect.Parameter] = None

Expand Down Expand Up @@ -961,19 +937,19 @@ def classify_autocompleter(autocompleter: AnyAutocompleter) -> None:

def collect_params(
function: Callable,
sig: Optional[inspect.Signature] = None,
parameters: Optional[Dict[str, inspect.Parameter]] = None,
) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]:
"""Collect all parameters in a function.

Optionally accepts an `inspect.Signature` object (as an optimization),
calls `signature(function)` if not provided.
Optionally accepts a `{str: inspect.Parameter}` dict as an optimization,
calls `get_signature_parameters(function)` if not provided.

Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`)
"""
if sig is None:
sig = signature(function)
if parameters is None:
parameters = get_signature_parameters(function)

(cog_param, inter_param), parameters = isolate_self(sig)
(cog_param, inter_param), parameters = isolate_self(parameters)

doc = disnake.utils.parse_docstring(function)["params"]

Expand Down Expand Up @@ -1097,10 +1073,10 @@ def expand_params(command: AnySlashCommand) -> List[Option]:

Returns the created options
"""
sig = signature(command.callback)
# pass `sig` down to avoid having to call `signature(func)` another time,
parameters = get_signature_parameters(command.callback)
# pass `parameters` down to avoid having to call `get_signature_parameters(func)` another time,
# which may cause side effects with deferred annotations and warnings
_, inter_param, params, injections = collect_params(command.callback, sig)
_, inter_param, params, injections = collect_params(command.callback, parameters)

if inter_param is None:
raise TypeError(f"Couldn't find an interaction parameter in {command.callback}")
Expand All @@ -1125,7 +1101,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
if param.autocomplete:
command.autocompleters[param.name] = param.autocomplete

if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction):
if issubclass_(parameters[inter_param].annotation, disnake.GuildCommandInteraction):
command._guild_only = True

return [param.to_option() for param in params]
Expand Down Expand Up @@ -1407,12 +1383,11 @@ def register_injection(
:class:`Injection`
The injection being registered.
"""
sig = signature(function)
tp = sig.return_annotation
tp = get_signature_return(function)

if tp is inspect.Parameter.empty:
raise TypeError("Injection must have a return annotation")
if tp in ParamInfo.TYPES:
raise TypeError("Injection cannot overwrite builtin types")

return Injection.register(function, sig.return_annotation, autocompleters=autocompleters)
return Injection.register(function, tp, autocompleters=autocompleters)
67 changes: 67 additions & 0 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import datetime
import functools
import inspect
import json
import os
import pkgutil
Expand Down Expand Up @@ -1203,6 +1204,72 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)


def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function


def _get_function_globals(function: Callable[..., Any]) -> Dict[str, Any]:
unwrap = unwrap_function(function)
try:
return unwrap.__globals__
except AttributeError:
return {}


_inspect_empty = inspect.Parameter.empty


def get_signature_parameters(
function: Callable[..., Any], globalns: Optional[Dict[str, Any]] = None
EQUENOS marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[str, inspect.Parameter]:
# if no globalns provided, unwrap (where needed) and get global namespace from there
if globalns is None:
globalns = _get_function_globals(function)

params: Dict[str, inspect.Parameter] = {}
cache: Dict[str, Any] = {}

signature = inspect.signature(function)

# eval all parameter annotations
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is _inspect_empty:
params[name] = parameter
continue

if annotation is None:
annotation = type(None)
else:
annotation = evaluate_annotation(annotation, globalns, globalns, cache)

params[name] = parameter.replace(annotation=annotation)

return params


def get_signature_return(function: Callable[..., Any]) -> Any:
signature = inspect.signature(function)

# same as parameters above, but for the return annotation
ret = signature.return_annotation
if ret is not _inspect_empty:
if ret is None:
ret = type(None)
else:
globalns = _get_function_globals(function)
ret = evaluate_annotation(ret, globalns, globalns, {})

return ret


TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]


Expand Down
10 changes: 5 additions & 5 deletions tests/ext/commands/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_isolate_self(self) -> None:
def func(a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is None
assert parameters == ({"a": mock.ANY})
Expand All @@ -80,7 +80,7 @@ def test_isolate_self_inter(self) -> None:
def func(i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand All @@ -89,7 +89,7 @@ def test_isolate_self_cog_inter(self) -> None:
def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is not None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand All @@ -98,7 +98,7 @@ def test_isolate_self_generic(self) -> None:
def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand All @@ -109,7 +109,7 @@ def func(
) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
(cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})
Expand Down