Skip to content

Commit

Permalink
refactor: use custom signature machinery for slash parsing as well
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Oct 8, 2023
1 parent 84ea3f8 commit 2574d06
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 54 deletions.
75 changes: 21 additions & 54 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
Union,
get_args,
get_origin,
get_type_hints,
)

import disnake
Expand All @@ -43,7 +42,7 @@
from disnake.ext import commands
from disnake.i18n import Localized
from disnake.interactions import ApplicationCommandInteraction
from disnake.utils import maybe_coroutine
from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine

from . import errors
from .converter import CONVERTER_MAPPING
Expand Down Expand Up @@ -135,37 +134,6 @@ def remove_optionals(annotation: Any) -> Any:
return annotation


def signature(func: Callable) -> inspect.Signature:
"""Get the signature with evaluated annotations wherever possible
This is equivalent to `signature(..., eval_str=True)` in python 3.10
"""
if sys.version_info >= (3, 10):
return inspect.signature(func, eval_str=True)

if inspect.isfunction(func) or inspect.ismethod(func):
typehints = get_type_hints(func)
else:
typehints = get_type_hints(func.__call__)

signature = inspect.signature(func)
parameters = []

for name, param in signature.parameters.items():
if isinstance(param.annotation, str):
param = param.replace(annotation=typehints.get(name, inspect.Parameter.empty))
if param.annotation is type(None):
param = param.replace(annotation=None)

parameters.append(param)

return_annotation = typehints.get("return", inspect.Parameter.empty)
if return_annotation is type(None):
return_annotation = None

return signature.replace(parameters=parameters, return_annotation=return_annotation)


def _xt_to_xe(xe: Optional[float], xt: Optional[float], direction: float = 1) -> Optional[float]:
"""Function for combining xt and xe
Expand Down Expand Up @@ -787,7 +755,7 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo
return True

def parse_converter_annotation(self, converter: Callable, fallback_annotation: Any) -> None:
_, parameters = isolate_self(signature(converter))
_, parameters = isolate_self(get_signature_parameters(converter))

if len(parameters) != 1:
raise TypeError(
Expand Down Expand Up @@ -850,9 +818,9 @@ def to_option(self) -> Option:
def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwargs: Any) -> T:
"""Calls a function without providing any extra unexpected arguments"""
MISSING: Any = object()
sig = signature(function)
parameters = get_signature_parameters(function)

kinds = {p.kind for p in sig.parameters.values()}
kinds = {p.kind for p in parameters.values()}
arb = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
if arb.issubset(kinds):
raise TypeError(
Expand All @@ -866,7 +834,7 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa

for index, parameter, posarg in itertools.zip_longest(
itertools.count(),
sig.parameters.values(),
parameters.values(),
possible_args,
fillvalue=MISSING,
):
Expand Down Expand Up @@ -895,15 +863,15 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa


def isolate_self(
sig: inspect.Signature,
parameters: Dict[str, inspect.Parameter],
) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]:
"""Create parameters without self and the first interaction"""
parameters = dict(sig.parameters)
parametersl = list(sig.parameters.values())

if not parameters:
return (None, None), {}

parameters = dict(parameters) # shallow copy
parametersl = list(parameters.values())

cog_param: Optional[inspect.Parameter] = None
inter_param: Optional[inspect.Parameter] = None

Expand Down Expand Up @@ -954,19 +922,19 @@ def classify_autocompleter(autocompleter: AnyAutocompleter) -> None:

def collect_params(
function: Callable,
sig: Optional[inspect.Signature] = None,
parameters: Optional[Dict[str, inspect.Parameter]] = None,
) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]:
"""Collect all parameters in a function.
Optionally accepts an `inspect.Signature` object (as an optimization),
calls `signature(function)` if not provided.
Optionally accepts a `{str: inspect.Parameter}` dict as an optimization,
calls `get_signature_parameters(function)` if not provided.
Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`)
"""
if sig is None:
sig = signature(function)
if parameters is None:
parameters = get_signature_parameters(function)

(cog_param, inter_param), parameters = isolate_self(sig)
(cog_param, inter_param), parameters = isolate_self(parameters)

doc = disnake.utils.parse_docstring(function)["params"]

Expand Down Expand Up @@ -1092,10 +1060,10 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
Returns the created options
"""
sig = signature(command.callback)
# pass `sig` down to avoid having to call `signature(func)` another time,
parameters = get_signature_parameters(command.callback)
# pass `parameters` down to avoid having to call `get_signature_parameters(func)` another time,
# which may cause side effects with deferred annotations and warnings
_, inter_param, params, injections = collect_params(command.callback, sig)
_, inter_param, params, injections = collect_params(command.callback, parameters)

if inter_param is None:
raise TypeError(f"Couldn't find an interaction parameter in {command.callback}")
Expand All @@ -1121,7 +1089,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
command.autocompleters[param.name] = param.autocomplete

if issubclass_(
get_origin(annot := sig.parameters[inter_param].annotation) or annot,
get_origin(annot := parameters[inter_param].annotation) or annot,
disnake.GuildCommandInteraction,
):
command._guild_only = True
Expand Down Expand Up @@ -1405,12 +1373,11 @@ def register_injection(
:class:`Injection`
The injection being registered.
"""
sig = signature(function)
tp = sig.return_annotation
tp = get_signature_return(function)

if tp is inspect.Parameter.empty:
raise TypeError("Injection must have a return annotation")
if tp in ParamInfo.TYPES:
raise TypeError("Injection cannot overwrite builtin types")

return Injection.register(function, sig.return_annotation, autocompleters=autocompleters)
return Injection.register(function, tp, autocompleters=autocompleters)
15 changes: 15 additions & 0 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,21 @@ def get_signature_parameters(
return params


def get_signature_return(function: Callable[..., Any]) -> Any:
signature = inspect.signature(function)

# same as parameters above, but for the return annotation
ret = signature.return_annotation
if ret is not _inspect_empty:
if ret is None:
ret = type(None)
else:
globalns = _get_function_globals(function)
ret = evaluate_annotation(ret, globalns, globalns, {})

return ret


TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]


Expand Down

0 comments on commit 2574d06

Please sign in to comment.