Skip to content

Commit

Permalink
fix(commands): handle interactions in union types correctly (#1121)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv authored Oct 26, 2023
1 parent 3591556 commit b23786b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 19 deletions.
1 change: 1 addition & 0 deletions changelog/1121.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`Interaction` and subtypes accept the bot type as a generic parameter to denote the type returned by the :attr:`~Interaction.bot` and :attr:`~Interaction.client` properties.
38 changes: 20 additions & 18 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
Type,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
Expand Down Expand Up @@ -110,17 +109,26 @@


def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[TypeT]:
"""Similar to the builtin `issubclass`, but more lenient.
Can also handle unions (`issubclass(Union[int, str], int)`) and
generic types (`issubclass(X[T], X)`) in the first argument.
"""
if not isinstance(tp, (type, tuple)):
return False
elif not isinstance(obj, type):
# Assume we have a type hint
if get_origin(obj) in (Union, UnionType, Optional):
obj = get_args(obj)
return any(isinstance(o, type) and issubclass(o, tp) for o in obj)
else:
# Other type hint specializations are not supported
return False
return issubclass(obj, tp)
elif isinstance(obj, type):
# common case
return issubclass(obj, tp)

# At this point, `obj` is likely a generic type hint
if (origin := get_origin(obj)) is None:
return False

if origin in (Union, UnionType):
# If we have a Union, try matching any of its args
# (recursively, to handle possibly generic types inside this union)
return any(issubclass_(o, tp) for o in obj.__args__)
else:
return isinstance(origin, type) and issubclass(origin, tp)


def remove_optionals(annotation: Any) -> Any:
Expand Down Expand Up @@ -912,7 +920,6 @@ def isolate_self(
parametersl.pop(0)
if parametersl:
annot = parametersl[0].annotation
annot = get_origin(annot) or annot
if issubclass_(annot, ApplicationCommandInteraction) or annot is inspect.Parameter.empty:
inter_param = parameters.pop(parametersl[0].name)

Expand Down Expand Up @@ -984,9 +991,7 @@ def collect_params(
injections[parameter.name] = default
elif parameter.annotation in Injection._registered:
injections[parameter.name] = Injection._registered[parameter.annotation]
elif issubclass_(
get_origin(parameter.annotation) or parameter.annotation, ApplicationCommandInteraction
):
elif issubclass_(parameter.annotation, ApplicationCommandInteraction):
if inter_param is None:
inter_param = parameter
else:
Expand Down Expand Up @@ -1120,10 +1125,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
if param.autocomplete:
command.autocompleters[param.name] = param.autocomplete

if issubclass_(
get_origin(annot := sig.parameters[inter_param].annotation) or annot,
disnake.GuildCommandInteraction,
):
if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction):
command._guild_only = True

return [param.to_option() for param in params]
Expand Down
49 changes: 48 additions & 1 deletion tests/ext/commands/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import disnake
from disnake import Member, Role, User
from disnake.ext import commands
from disnake.ext.commands import params

OptionType = disnake.OptionType

Expand Down Expand Up @@ -66,6 +67,53 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None:
with pytest.raises(commands.errors.MemberNotFound):
await info.verify_type(mock.Mock(), arg_mock)

def test_isolate_self(self) -> None:
def func(a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_inter(self) -> None:
def func(i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_cog_inter(self) -> None:
def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is not None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_generic(self) -> None:
def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_union(self) -> None:
def func(
i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int
) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})


# this uses `Range` for testing `_BaseRange`, `String` should work equally
class TestBaseRange:
Expand Down Expand Up @@ -189,7 +237,6 @@ def test_string(self) -> None:
assert info.max_value is None
assert info.type == annotation.underlying_type

# uses lambdas since new union syntax isn't supported on all versions
@pytest.mark.parametrize(
"annotation_str",
[
Expand Down

0 comments on commit b23786b

Please sign in to comment.