Skip to content

Commit

Permalink
refactor: move signature utility functions from commands.core to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Oct 8, 2023
1 parent 1f6104d commit d0bb39d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 38 deletions.
49 changes: 11 additions & 38 deletions disnake/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
)

import disnake
from disnake.utils import _generated, _overload_with_permissions
from disnake.utils import (
_generated,
_overload_with_permissions,
get_signature_parameters,
unwrap_function,
)

from ._types import _BaseCommand
from .cog import Cog
Expand Down Expand Up @@ -114,42 +119,6 @@
P = TypeVar("P")


def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function


def get_signature_parameters(
function: Callable[..., Any], globalns: Dict[str, Any]
) -> Dict[str, inspect.Parameter]:
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
eval_annotation = disnake.utils.evaluate_annotation
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
params[name] = parameter
continue
if annotation is None:
params[name] = parameter.replace(annotation=type(None))
continue

annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")

params[name] = parameter.replace(annotation=annotation)

return params


def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
Expand Down Expand Up @@ -410,7 +379,11 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None:
except AttributeError:
globalns = {}

self.params = get_signature_parameters(function, globalns)
params = get_signature_parameters(function, globalns)
for param in params.values():
if param.annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
self.params = params

def add_check(self, func: Check) -> None:
"""Adds a check to the command.
Expand Down
33 changes: 33 additions & 0 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import datetime
import functools
import inspect
import json
import os
import pkgutil
Expand Down Expand Up @@ -1200,6 +1201,38 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)


def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, "__wrapped__"):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function


def get_signature_parameters(
function: Callable[..., Any], globalns: Dict[str, Any]
) -> Dict[str, inspect.Parameter]:
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
params[name] = parameter
continue
if annotation is None:
params[name] = parameter.replace(annotation=type(None))
continue

annotation = evaluate_annotation(annotation, globalns, globalns, cache)
params[name] = parameter.replace(annotation=annotation)

return params


TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]


Expand Down

0 comments on commit d0bb39d

Please sign in to comment.