From 857dd2de11713cd9130efe685c2b040703e8e6eb Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:43:36 +0200 Subject: [PATCH 1/6] fix(threads): move runtime import required for `Thread.permissions_for` (#1124) --- changelog/1123.bugfix.rst | 1 + disnake/threads.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog/1123.bugfix.rst diff --git a/changelog/1123.bugfix.rst b/changelog/1123.bugfix.rst new file mode 100644 index 0000000000..625f275381 --- /dev/null +++ b/changelog/1123.bugfix.rst @@ -0,0 +1 @@ +Fix :meth:`Thread.permissions_for` not working in some cases due to an incorrect import. diff --git a/disnake/threads.py b/disnake/threads.py index 2457c5a879..d759e272b5 100644 --- a/disnake/threads.py +++ b/disnake/threads.py @@ -12,6 +12,7 @@ from .flags import ChannelFlags from .mixins import Hashable from .partial_emoji import PartialEmoji, _EmojiTag +from .permissions import Permissions from .utils import MISSING, _get_as_snowflake, _unique, parse_time, snowflake_time __all__ = ( @@ -31,7 +32,6 @@ from .guild import Guild from .member import Member from .message import Message, PartialMessage - from .permissions import Permissions from .role import Role from .state import ConnectionState from .types.snowflake import SnowflakeList From c2a8dde6261af318a6e79a69d4af60fdd0d0fc63 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:57:40 +0200 Subject: [PATCH 2/6] fix(commands): resolve unstringified annotation before caching (#1120) --- changelog/1120.bugfix.rst | 1 + disnake/utils.py | 13 +++++++------ tests/test_utils.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 changelog/1120.bugfix.rst diff --git a/changelog/1120.bugfix.rst b/changelog/1120.bugfix.rst new file mode 100644 index 0000000000..146ac4a8de --- /dev/null +++ b/changelog/1120.bugfix.rst @@ -0,0 +1 @@ +|commands| Fix edge case in evaluation of multiple identical annotations with forwardrefs in a single signature. diff --git a/disnake/utils.py b/disnake/utils.py index 95f35003ce..1d06f137d2 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1134,13 +1134,14 @@ def evaluate_annotation( if implicit_str and isinstance(tp, str): if tp in cache: return cache[tp] - evaluated = ( - eval( # noqa: PGH001, S307 # this is how annotations are supposed to be unstringifed - tp, globals, locals - ) - ) + + # this is how annotations are supposed to be unstringifed + evaluated = eval(tp, globals, locals) # noqa: PGH001, S307 + # recurse to resolve nested args further + evaluated = evaluate_annotation(evaluated, globals, locals, cache) + cache[tp] = evaluated - return evaluate_annotation(evaluated, globals, locals, cache) + return evaluated if hasattr(tp, "__args__"): implicit_str = True diff --git a/tests/test_utils.py b/tests/test_utils.py index 48ef75134a..a8f52e6b1f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -758,8 +758,8 @@ def test_normalise_optional_params(params, expected) -> None: ("Tuple[dict, List[Literal[42, 99]]]", Tuple[dict, List[Literal[42, 99]]], True), # 3.10 union syntax pytest.param( - "int | Literal[False]", - Union[int, Literal[False]], + "int | float", + Union[int, float], True, marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="syntax requires py3.10"), ), From ffb7526c823f002c39434182c56e765201f37bfd Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 26 Oct 2023 15:35:54 +0200 Subject: [PATCH 3/6] feat(lint): enable ruff `TCH` (flake8-type-checking) rules (#1125) --- disnake/ext/commands/common_bot_base.py | 3 +-- docs/extensions/attributetable.py | 3 ++- docs/extensions/builder.py | 3 ++- docs/extensions/exception_hierarchy.py | 3 ++- docs/extensions/fulltoc.py | 3 ++- docs/extensions/nitpick_file_ignorer.py | 3 ++- docs/extensions/redirects.py | 7 +++++-- docs/extensions/resourcelinks.py | 3 ++- pyproject.toml | 9 ++++++++- tests/ext/commands/test_core.py | 12 +----------- tests/helpers.py | 9 +++++++++ tests/ui/test_action_row.py | 22 ++++++++++++---------- 12 files changed, 48 insertions(+), 32 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 841c3df837..f0d8fd5566 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -4,6 +4,7 @@ import asyncio import collections.abc +import importlib.machinery import importlib.util import logging import os @@ -19,8 +20,6 @@ from .cog import Cog if TYPE_CHECKING: - import importlib.machinery - from ._types import CoroFunc from .bot import AutoShardedBot, AutoShardedInteractionBot, Bot, InteractionBot from .help import HelpCommand diff --git a/docs/extensions/attributetable.py b/docs/extensions/attributetable.py index 1a66a0c026..d718c73604 100644 --- a/docs/extensions/attributetable.py +++ b/docs/extensions/attributetable.py @@ -14,12 +14,13 @@ from sphinx.util.docutils import SphinxDirective if TYPE_CHECKING: - from _types import SphinxExtensionMeta from sphinx.application import Sphinx from sphinx.environment import BuildEnvironment from sphinx.util.typing import OptionSpec from sphinx.writers.html import HTMLTranslator + from ._types import SphinxExtensionMeta + class attributetable(nodes.General, nodes.Element): pass diff --git a/docs/extensions/builder.py b/docs/extensions/builder.py index 6f1a5493d4..5133af0f85 100644 --- a/docs/extensions/builder.py +++ b/docs/extensions/builder.py @@ -8,12 +8,13 @@ from sphinx.environment.adapters.indexentries import IndexEntries if TYPE_CHECKING: - from _types import SphinxExtensionMeta from docutils import nodes from sphinx.application import Sphinx from sphinx.config import Config from sphinx.writers.html5 import HTML5Translator + from ._types import SphinxExtensionMeta + if TYPE_CHECKING: translator_base = HTML5Translator else: diff --git a/docs/extensions/exception_hierarchy.py b/docs/extensions/exception_hierarchy.py index 147a175af2..67040643ae 100644 --- a/docs/extensions/exception_hierarchy.py +++ b/docs/extensions/exception_hierarchy.py @@ -7,10 +7,11 @@ from docutils.parsers.rst import Directive if TYPE_CHECKING: - from _types import SphinxExtensionMeta from sphinx.application import Sphinx from sphinx.writers.html import HTMLTranslator + from ._types import SphinxExtensionMeta + class exception_hierarchy(nodes.General, nodes.Element): pass diff --git a/docs/extensions/fulltoc.py b/docs/extensions/fulltoc.py index 1d7523e52a..e35cd79514 100644 --- a/docs/extensions/fulltoc.py +++ b/docs/extensions/fulltoc.py @@ -31,7 +31,6 @@ from typing import TYPE_CHECKING, List, cast -from _types import SphinxExtensionMeta from docutils import nodes from sphinx import addnodes @@ -40,6 +39,8 @@ from sphinx.builders.html import StandaloneHTMLBuilder from sphinx.environment import BuildEnvironment + from ._types import SphinxExtensionMeta + # {prefix: index_doc} mapping # Any document that matches `prefix` will use `index_doc`'s toctree instead. GROUPED_SECTIONS = {"api/": "api/index", "ext/commands/api/": "ext/commands/api/index"} diff --git a/docs/extensions/nitpick_file_ignorer.py b/docs/extensions/nitpick_file_ignorer.py index da967f9d92..cc9eab588f 100644 --- a/docs/extensions/nitpick_file_ignorer.py +++ b/docs/extensions/nitpick_file_ignorer.py @@ -7,9 +7,10 @@ from sphinx.util import logging as sphinx_logging if TYPE_CHECKING: - from _types import SphinxExtensionMeta from sphinx.application import Sphinx + from ._types import SphinxExtensionMeta + class NitpickFileIgnorer(logging.Filter): def __init__(self, app: Sphinx) -> None: diff --git a/docs/extensions/redirects.py b/docs/extensions/redirects.py index fea63483be..44d1c16bef 100644 --- a/docs/extensions/redirects.py +++ b/docs/extensions/redirects.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: MIT +from __future__ import annotations import json from pathlib import Path -from typing import Dict +from typing import TYPE_CHECKING, Dict -from _types import SphinxExtensionMeta from sphinx.application import Sphinx from sphinx.util.fileutil import copy_asset_file +if TYPE_CHECKING: + from ._types import SphinxExtensionMeta + SCRIPT_PATH = "_templates/api_redirect.js_t" diff --git a/docs/extensions/resourcelinks.py b/docs/extensions/resourcelinks.py index 76a57b4656..d93f6f2715 100644 --- a/docs/extensions/resourcelinks.py +++ b/docs/extensions/resourcelinks.py @@ -10,12 +10,13 @@ from sphinx.util.nodes import split_explicit_title if TYPE_CHECKING: - from _types import SphinxExtensionMeta from docutils.nodes import Node, system_message from docutils.parsers.rst.states import Inliner from sphinx.application import Sphinx from sphinx.util.typing import RoleFunction + from ._types import SphinxExtensionMeta + def make_link_role(resource_links: Dict[str, str]) -> RoleFunction: def role( diff --git a/pyproject.toml b/pyproject.toml index a2e0854836..984bdf767f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,7 +149,7 @@ select = [ # "RET", # flake8-return # "SIM", # flake8-simplify "TID251", # flake8-tidy-imports, replaces S404 - # "TCH", # flake8-type-checking + "TCH", # flake8-type-checking "RUF", # ruff specific exceptions "PT", # flake8-pytest-style "Q", # flake8-quotes @@ -198,6 +198,13 @@ ignore = [ # outer loop variables are overwritten by inner assignment target, these are mostly intentional "PLW2901", + # ignore imports that could be moved into type-checking blocks + # (no real advantage other than possibly avoiding cycles, + # but can be dangerous in places where we need to parse signatures) + "TCH001", + "TCH002", + "TCH003", + # temporary disables, to fix later "D205", # blank line required between summary and description "D401", # first line of docstring should be in imperative mood diff --git a/tests/ext/commands/test_core.py b/tests/ext/commands/test_core.py index 1d3076a845..2b29f51988 100644 --- a/tests/ext/commands/test_core.py +++ b/tests/ext/commands/test_core.py @@ -1,20 +1,10 @@ # SPDX-License-Identifier: MIT -from typing import TYPE_CHECKING +from typing_extensions import assert_type from disnake.ext import commands from tests.helpers import reveal_type -if TYPE_CHECKING: - from typing_extensions import assert_type - - # NOTE: using undocumented `expected_text` parameter of pyright instead of `assert_type`, - # as `assert_type` can't handle bound ParamSpecs - reveal_type( - 42, # type: ignore - expected_text="str", # type: ignore - ) - class CustomContext(commands.Context): ... diff --git a/tests/helpers.py b/tests/helpers.py index 2d5a4d8e41..8e22e0cd08 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,6 +16,15 @@ def reveal_type(*args, **kwargs) -> None: raise RuntimeError +if TYPE_CHECKING: + # NOTE: using undocumented `expected_text` parameter of pyright instead of `assert_type`, + # as `assert_type` can't handle bound ParamSpecs + reveal_type( + 42, # type: ignore # suppress "revealed type is ..." output + expected_text="str", # type: ignore # ensure the functionality we want still works as expected + ) + + CallableT = TypeVar("CallableT", bound=Callable) diff --git a/tests/ui/test_action_row.py b/tests/ui/test_action_row.py index 9e72ecc3eb..f9c40ffedc 100644 --- a/tests/ui/test_action_row.py +++ b/tests/ui/test_action_row.py @@ -4,17 +4,20 @@ from unittest import mock import pytest +from typing_extensions import assert_type import disnake -from disnake.ui import ActionRow, Button, StringSelect, TextInput, WrappedComponent +from disnake.ui import ( + ActionRow, + Button, + MessageUIComponent, + ModalUIComponent, + StringSelect, + TextInput, + WrappedComponent, +) from disnake.ui.action_row import components_to_dict, components_to_rows -if TYPE_CHECKING: - from typing_extensions import assert_type - - from disnake.ui import MessageUIComponent, ModalUIComponent - - button1 = Button() button2 = Button() button3 = Button() @@ -133,9 +136,8 @@ def test_with_components(self) -> None: row_msg = ActionRow.with_message_components() assert list(row_msg.children) == [] - if TYPE_CHECKING: - assert_type(row_modal, ActionRow[ModalUIComponent]) - assert_type(row_msg, ActionRow[MessageUIComponent]) + assert_type(row_modal, ActionRow[ModalUIComponent]) + assert_type(row_msg, ActionRow[MessageUIComponent]) def test_rows_from_message(self) -> None: rows = [ From 35915569531552d93145dd7e4ceff202f0c1a70f Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 26 Oct 2023 15:53:04 +0200 Subject: [PATCH 4/6] docs: make `Supported Operations` container collapsible (#1126) --- changelog/1126.doc.rst | 1 + disnake/activity.py | 8 ++-- disnake/asset.py | 2 +- disnake/audit_logs.py | 2 +- disnake/channel.py | 16 +++---- disnake/colour.py | 2 +- disnake/embeds.py | 2 +- disnake/emoji.py | 2 +- disnake/ext/commands/flag_converter.py | 2 +- disnake/ext/commands/flags.py | 2 +- disnake/ext/commands/help.py | 2 +- disnake/flags.py | 22 +++++----- disnake/guild.py | 2 +- disnake/guild_scheduled_event.py | 2 +- disnake/invite.py | 6 +-- disnake/member.py | 2 +- disnake/message.py | 6 +-- disnake/object.py | 2 +- disnake/partial_emoji.py | 2 +- disnake/permissions.py | 4 +- disnake/reaction.py | 2 +- disnake/role.py | 2 +- disnake/stage_instance.py | 2 +- disnake/sticker.py | 10 ++--- disnake/team.py | 2 +- disnake/threads.py | 6 +-- disnake/ui/action_row.py | 2 +- disnake/user.py | 4 +- disnake/voice_region.py | 2 +- disnake/webhook/async_.py | 2 +- disnake/webhook/sync.py | 2 +- disnake/widget.py | 6 +-- docs/_static/style.css | 15 +++---- docs/api/audit_logs.rst | 2 +- docs/api/guilds.rst | 12 ++---- docs/api/messages.rst | 10 ++--- docs/api/misc.rst | 2 +- docs/conf.py | 1 + docs/extensions/collapse.py | 60 ++++++++++++++++++++++++++ 39 files changed, 145 insertions(+), 88 deletions(-) create mode 100644 changelog/1126.doc.rst create mode 100644 docs/extensions/collapse.py diff --git a/changelog/1126.doc.rst b/changelog/1126.doc.rst new file mode 100644 index 0000000000..44fa13bf31 --- /dev/null +++ b/changelog/1126.doc.rst @@ -0,0 +1 @@ +Make all "Supported Operations" container elements collapsible. diff --git a/disnake/activity.py b/disnake/activity.py index a213bf5a75..92460cd35d 100644 --- a/disnake/activity.py +++ b/disnake/activity.py @@ -404,7 +404,7 @@ class Game(BaseActivity): This is typically displayed via **Playing** on the official Discord client. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -487,7 +487,7 @@ class Streaming(BaseActivity): This is typically displayed via **Streaming** on the official Discord client. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -597,7 +597,7 @@ def __hash__(self) -> int: class Spotify(_BaseActivity): """Represents a Spotify listening activity from Discord. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -770,7 +770,7 @@ def party_id(self) -> str: class CustomActivity(BaseActivity): """Represents a Custom activity from Discord. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/asset.py b/disnake/asset.py index fc8fa6c7ea..fad72c79ce 100644 --- a/disnake/asset.py +++ b/disnake/asset.py @@ -164,7 +164,7 @@ async def to_file( class Asset(AssetMixin): """Represents a CDN asset on Discord. - .. container:: operations + .. collapse:: operations .. describe:: str(x) diff --git a/disnake/audit_logs.py b/disnake/audit_logs.py index 9d45912cd9..e8ab022edf 100644 --- a/disnake/audit_logs.py +++ b/disnake/audit_logs.py @@ -517,7 +517,7 @@ class AuditLogEntry(Hashable): You can retrieve these via :meth:`Guild.audit_logs`, or via the :func:`on_audit_log_entry_create` event. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/channel.py b/disnake/channel.py index 263735e24e..7eef52b942 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -103,7 +103,7 @@ async def _single_delete_strategy(messages: Iterable[Message]) -> None: class TextChannel(disnake.abc.Messageable, disnake.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1217,7 +1217,7 @@ def permissions_for( class VoiceChannel(disnake.abc.Messageable, VocalGuildChannel): """Represents a Discord guild voice channel. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1871,7 +1871,7 @@ class StageChannel(disnake.abc.Messageable, VocalGuildChannel): .. versionadded:: 1.7 - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -2696,7 +2696,7 @@ class CategoryChannel(disnake.abc.GuildChannel, Hashable): These are useful to group channels to logical compartments. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -3145,7 +3145,7 @@ class ForumChannel(disnake.abc.GuildChannel, Hashable): .. versionadded:: 2.5 - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -4184,7 +4184,7 @@ def get_tag_by_name(self, name: str, /) -> Optional[ForumTag]: class DMChannel(disnake.abc.Messageable, Hashable): """Represents a Discord direct message channel. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -4347,7 +4347,7 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: class GroupChannel(disnake.abc.Messageable, Hashable): """Represents a Discord group channel. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -4506,7 +4506,7 @@ class PartialMessageable(disnake.abc.Messageable, Hashable): .. versionadded:: 2.0 - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/colour.py b/disnake/colour.py index 82e8ef1bb3..4bd6585ea2 100644 --- a/disnake/colour.py +++ b/disnake/colour.py @@ -22,7 +22,7 @@ class Colour: There is an alias for this called Color. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/embeds.py b/disnake/embeds.py index 22ae7398af..1866d8d7eb 100644 --- a/disnake/embeds.py +++ b/disnake/embeds.py @@ -112,7 +112,7 @@ class _EmbedAuthorProxy(Sized, Protocol): class Embed: """Represents a Discord embed. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/emoji.py b/disnake/emoji.py index fb5ee1c3b4..0f3d02c27d 100644 --- a/disnake/emoji.py +++ b/disnake/emoji.py @@ -28,7 +28,7 @@ class Emoji(_EmojiTag, AssetMixin): Depending on the way this object was created, some of the attributes can have a value of ``None``. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/ext/commands/flag_converter.py b/disnake/ext/commands/flag_converter.py index 39a4b54808..37c97936c0 100644 --- a/disnake/ext/commands/flag_converter.py +++ b/disnake/ext/commands/flag_converter.py @@ -435,7 +435,7 @@ class FlagConverter(metaclass=FlagsMeta): how this converter works, check the appropriate :ref:`documentation `. - .. container:: operations + .. collapse:: operations .. describe:: iter(x) diff --git a/disnake/ext/commands/flags.py b/disnake/ext/commands/flags.py index ade3e79182..866566af3b 100644 --- a/disnake/ext/commands/flags.py +++ b/disnake/ext/commands/flags.py @@ -25,7 +25,7 @@ class CommandSyncFlags(BaseFlags): .. versionadded:: 2.7 - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/ext/commands/help.py b/disnake/ext/commands/help.py index 25a8c247dc..5841a8ba11 100644 --- a/disnake/ext/commands/help.py +++ b/disnake/ext/commands/help.py @@ -48,7 +48,7 @@ class Paginator: """A class that aids in paginating code blocks for Discord messages. - .. container:: operations + .. collapse:: operations .. describe:: len(x) diff --git a/disnake/flags.py b/disnake/flags.py index 66b9b4b369..63fc6bf2c6 100644 --- a/disnake/flags.py +++ b/disnake/flags.py @@ -329,7 +329,7 @@ class SystemChannelFlags(BaseFlags, inverted=True): to enable or disable. Arguments are applied in order, similar to :class:`Permissions`. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -491,7 +491,7 @@ class MessageFlags(BaseFlags): See :class:`SystemChannelFlags`. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -681,7 +681,7 @@ def is_voice_message(self): class PublicUserFlags(BaseFlags): """Wraps up the Discord User Public flags. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -930,7 +930,7 @@ class Intents(BaseFlags): .. versionadded:: 1.5 - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1617,7 +1617,7 @@ class MemberCacheFlags(BaseFlags): .. versionadded:: 1.5 - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1793,7 +1793,7 @@ def _voice_only(self): class ApplicationFlags(BaseFlags): """Wraps up the Discord Application flags. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1968,7 +1968,7 @@ def application_command_badge(self): class ChannelFlags(BaseFlags): """Wraps up the Discord Channel flags. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -2081,7 +2081,7 @@ def require_tag(self): class AutoModKeywordPresets(ListBaseFlags): """Wraps up the pre-defined auto moderation keyword lists, provided by Discord. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -2194,7 +2194,7 @@ def slurs(self): class MemberFlags(BaseFlags): """Wraps up Discord Member flags. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -2296,7 +2296,7 @@ def started_onboarding(self): class RoleFlags(BaseFlags): """Wraps up Discord Role flags. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -2376,7 +2376,7 @@ def in_prompt(self): class AttachmentFlags(BaseFlags): """Wraps up Discord Attachment flags. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/guild.py b/disnake/guild.py index 888c7518d4..3927992fb5 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -130,7 +130,7 @@ class Guild(Hashable): This is referred to as a "server" in the official Discord UI. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/guild_scheduled_event.py b/disnake/guild_scheduled_event.py index a9739b217a..63b23620fe 100644 --- a/disnake/guild_scheduled_event.py +++ b/disnake/guild_scheduled_event.py @@ -77,7 +77,7 @@ class GuildScheduledEvent(Hashable): .. versionadded:: 2.3 - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/invite.py b/disnake/invite.py index 2d95ea6d8a..a936c832b1 100644 --- a/disnake/invite.py +++ b/disnake/invite.py @@ -48,7 +48,7 @@ class PartialInviteChannel: guild the :class:`Invite` resolves to. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -137,7 +137,7 @@ class PartialInviteGuild: This model will be given when the user is not part of the guild the :class:`Invite` resolves to. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -256,7 +256,7 @@ class Invite(Hashable): Depending on the way this object was created, some of the attributes can have a value of ``None`` (see table below). - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/member.py b/disnake/member.py index 25886079eb..fb1a98e6c8 100644 --- a/disnake/member.py +++ b/disnake/member.py @@ -212,7 +212,7 @@ class Member(disnake.abc.Messageable, _UserTag): This implements a lot of the functionality of :class:`User`. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/message.py b/disnake/message.py index 21f59e269e..92aba532c7 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -220,7 +220,7 @@ async def _edit_handler( class Attachment(Hashable): """Represents an attachment from Discord. - .. container:: operations + .. collapse:: operations .. describe:: str(x) @@ -766,7 +766,7 @@ def flatten_handlers(cls): class Message(Hashable): """Represents a message from Discord. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -2177,7 +2177,7 @@ class PartialMessage(Hashable): .. versionadded:: 1.6 - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/object.py b/disnake/object.py index 9af6a758b7..cd3048b6b1 100644 --- a/disnake/object.py +++ b/disnake/object.py @@ -29,7 +29,7 @@ class Object(Hashable): receive this class rather than the actual data class. These cases are extremely rare. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/partial_emoji.py b/disnake/partial_emoji.py index ab124d28e1..92656bb314 100644 --- a/disnake/partial_emoji.py +++ b/disnake/partial_emoji.py @@ -38,7 +38,7 @@ class PartialEmoji(_EmojiTag, AssetMixin): - "Raw" data events such as :func:`on_raw_reaction_add` - Custom emoji that the bot cannot see from e.g. :attr:`Message.reactions` - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/permissions.py b/disnake/permissions.py index 8046f14d4a..a7df815caa 100644 --- a/disnake/permissions.py +++ b/disnake/permissions.py @@ -76,7 +76,7 @@ class Permissions(BaseFlags): You can now use keyword arguments to initialize :class:`Permissions` similar to :meth:`update`. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1036,7 +1036,7 @@ class PermissionOverwrite: The values supported by this are the same as :class:`Permissions` with the added possibility of it being set to ``None``. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/reaction.py b/disnake/reaction.py index 5a3c784627..0720759f6a 100644 --- a/disnake/reaction.py +++ b/disnake/reaction.py @@ -22,7 +22,7 @@ class Reaction: Depending on the way this object was created, some of the attributes can have a value of ``None``. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/role.py b/disnake/role.py index addd6b7551..89fa55804f 100644 --- a/disnake/role.py +++ b/disnake/role.py @@ -140,7 +140,7 @@ def __repr__(self) -> str: class Role(Hashable): """Represents a Discord role in a :class:`Guild`. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/stage_instance.py b/disnake/stage_instance.py index 08f50dc3e1..deff882916 100644 --- a/disnake/stage_instance.py +++ b/disnake/stage_instance.py @@ -24,7 +24,7 @@ class StageInstance(Hashable): .. versionadded:: 2.0 - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/sticker.py b/disnake/sticker.py index 01ce53b9d3..0d94c1ebcc 100644 --- a/disnake/sticker.py +++ b/disnake/sticker.py @@ -44,7 +44,7 @@ class StickerPack(Hashable): .. versionchanged:: 2.8 :attr:`cover_sticker_id`, :attr:`cover_sticker` and :attr:`banner` are now optional. - .. container:: operations + .. collapse:: operations .. describe:: str(x) @@ -163,7 +163,7 @@ class StickerItem(_StickerTag): .. versionadded:: 2.0 - .. container:: operations + .. collapse:: operations .. describe:: str(x) @@ -226,7 +226,7 @@ class Sticker(_StickerTag): .. versionadded:: 1.6 - .. container:: operations + .. collapse:: operations .. describe:: str(x) @@ -283,7 +283,7 @@ class StandardSticker(Sticker): .. versionadded:: 2.0 - .. container:: operations + .. collapse:: operations .. describe:: str(x) @@ -362,7 +362,7 @@ class GuildSticker(Sticker): .. versionadded:: 2.0 - .. container:: operations + .. collapse:: operations .. describe:: str(x) diff --git a/disnake/team.py b/disnake/team.py index 1034904cd9..dd0ee48d76 100644 --- a/disnake/team.py +++ b/disnake/team.py @@ -77,7 +77,7 @@ def owner(self) -> Optional[TeamMember]: class TeamMember(BaseUser): """Represents a team member in a team. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/threads.py b/disnake/threads.py index d759e272b5..fb0b2add92 100644 --- a/disnake/threads.py +++ b/disnake/threads.py @@ -54,7 +54,7 @@ class Thread(Messageable, Hashable): """Represents a Discord thread. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1018,7 +1018,7 @@ def _pop_member(self, member_id: int) -> Optional[ThreadMember]: class ThreadMember(Hashable): """Represents a Discord thread member. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -1092,7 +1092,7 @@ def thread(self) -> Thread: class ForumTag(Hashable): """Represents a tag for threads in forum channels. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/ui/action_row.py b/disnake/ui/action_row.py index b8473badb0..fe7244a776 100644 --- a/disnake/ui/action_row.py +++ b/disnake/ui/action_row.py @@ -91,7 +91,7 @@ class ActionRow(Generic[UIComponentT]): """Represents a UI action row. Useful for lower level component manipulation. - .. container:: operations + .. collapse:: operations .. describe:: x[i] diff --git a/disnake/user.py b/disnake/user.py index b2b05acb54..4326016100 100644 --- a/disnake/user.py +++ b/disnake/user.py @@ -281,7 +281,7 @@ def mentioned_in(self, message: Message) -> bool: class ClientUser(BaseUser): """Represents your Discord user. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -419,7 +419,7 @@ async def edit( class User(BaseUser, disnake.abc.Messageable): """Represents a Discord user. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/voice_region.py b/disnake/voice_region.py index 6c957d5e04..b08689db5b 100644 --- a/disnake/voice_region.py +++ b/disnake/voice_region.py @@ -14,7 +14,7 @@ class VoiceRegion: """Represents a Discord voice region. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/webhook/async_.py b/disnake/webhook/async_.py index edd9ec3dcd..4bd6f4b2d1 100644 --- a/disnake/webhook/async_.py +++ b/disnake/webhook/async_.py @@ -1034,7 +1034,7 @@ async def foo(): For a synchronous counterpart, see :class:`SyncWebhook`. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/webhook/sync.py b/disnake/webhook/sync.py index 0d2ba42b6c..b1debb9cf3 100644 --- a/disnake/webhook/sync.py +++ b/disnake/webhook/sync.py @@ -510,7 +510,7 @@ class SyncWebhook(BaseWebhook): For an asynchronous counterpart, see :class:`Webhook`. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/disnake/widget.py b/disnake/widget.py index d5056d82dc..4293985a36 100644 --- a/disnake/widget.py +++ b/disnake/widget.py @@ -34,7 +34,7 @@ class WidgetChannel: """Represents a "partial" widget channel. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -89,7 +89,7 @@ def created_at(self) -> datetime.datetime: class WidgetMember(BaseUser): """Represents a "partial" member of the widget's guild. - .. container:: operations + .. collapse:: operations .. describe:: x == y @@ -262,7 +262,7 @@ async def edit( class Widget: """Represents a :class:`Guild` widget. - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/docs/_static/style.css b/docs/_static/style.css index a54a5ecf52..b89e43a930 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1313,21 +1313,20 @@ rect.highlighted { fill: var(--highlighted-text); } -.container.operations { +details.operations { padding: 10px; border: 1px solid var(--codeblock-border); margin-bottom: 20px; } -.container.operations::before { - content: 'Supported Operations'; - color: var(--main-big-headers-text); - display: block; - padding-bottom: 0.5em; +details.operations dl { + margin-top: 15px; + margin-bottom: 15px; } -.container.operations > dl.describe > dt { - background-color: var(--api-entry-background); +details.operations > summary::after { + content: 'Supported Operations'; + color: var(--main-big-headers-text); } .table-wrapper { diff --git a/docs/api/audit_logs.rst b/docs/api/audit_logs.rst index 1052c610b4..e4cb0573d2 100644 --- a/docs/api/audit_logs.rst +++ b/docs/api/audit_logs.rst @@ -93,7 +93,7 @@ AuditLogDiff on the action being done, check the documentation for :class:`AuditLogAction`, otherwise check the documentation below for all attributes that are possible. - .. container:: operations + .. collapse:: operations .. describe:: iter(diff) diff --git a/docs/api/guilds.rst b/docs/api/guilds.rst index 85f8d3d494..72a7ad6f84 100644 --- a/docs/api/guilds.rst +++ b/docs/api/guilds.rst @@ -131,7 +131,7 @@ VerificationLevel Specifies a :class:`Guild`\'s verification level, which is the criteria in which a member must meet before being able to send messages to the guild. - .. container:: operations + .. collapse:: operations .. versionadded:: 2.0 @@ -180,9 +180,7 @@ NotificationLevel Specifies whether a :class:`Guild` has notifications on for all messages or mentions only by default. - .. container:: operations - - .. versionadded:: 2.0 + .. collapse:: operations .. describe:: x == y @@ -219,9 +217,7 @@ ContentFilter learning algorithms that Discord uses to detect if an image contains NSFW content. - .. container:: operations - - .. versionadded:: 2.0 + .. collapse:: operations .. describe:: x == y @@ -261,7 +257,7 @@ NSFWLevel .. versionadded:: 2.0 - .. container:: operations + .. collapse:: operations .. describe:: x == y diff --git a/docs/api/messages.rst b/docs/api/messages.rst index 123260f167..3031d955d9 100644 --- a/docs/api/messages.rst +++ b/docs/api/messages.rst @@ -188,14 +188,14 @@ MessageType Specifies the type of :class:`Message`. This is used to denote if a message is to be interpreted as a system message or a regular message. - .. container:: operations + .. collapse:: operations - .. describe:: x == y + .. describe:: x == y - Checks if two messages are equal. - .. describe:: x != y + Checks if two messages are equal. + .. describe:: x != y - Checks if two messages are not equal. + Checks if two messages are not equal. .. attribute:: default diff --git a/docs/api/misc.rst b/docs/api/misc.rst index c49854c289..83ee3298d8 100644 --- a/docs/api/misc.rst +++ b/docs/api/misc.rst @@ -18,7 +18,7 @@ AsyncIterator Represents the "AsyncIterator" concept. Note that no such class exists, it is purely abstract. - .. container:: operations + .. collapse:: operations .. describe:: async for x in y diff --git a/docs/conf.py b/docs/conf.py index 355f977465..5944191079 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -52,6 +52,7 @@ "exception_hierarchy", "attributetable", "resourcelinks", + "collapse", "nitpick_file_ignorer", ] diff --git a/docs/extensions/collapse.py b/docs/extensions/collapse.py new file mode 100644 index 0000000000..cde568aea2 --- /dev/null +++ b/docs/extensions/collapse.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: MIT +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from docutils import nodes +from docutils.parsers.rst import Directive, directives + +if TYPE_CHECKING: + from sphinx.application import Sphinx + from sphinx.util.typing import OptionSpec + from sphinx.writers.html import HTMLTranslator + + from ._types import SphinxExtensionMeta + + +class collapse(nodes.General, nodes.Element): + pass + + +def visit_collapse_node(self: HTMLTranslator, node: nodes.Element) -> None: + attrs = {"open": ""} if node["open"] else {} + self.body.append(self.starttag(node, "details", **attrs)) + self.body.append("") + + +def depart_collapse_node(self: HTMLTranslator, node: nodes.Element) -> None: + self.body.append("\n") + + +class CollapseDirective(Directive): + has_content = True + + optional_arguments = 1 + final_argument_whitespace = True + + option_spec: ClassVar[OptionSpec] = {"open": directives.flag} + + def run(self): + self.assert_has_content() + node = collapse( + "\n".join(self.content), + open="open" in self.options, + ) + + classes = directives.class_option(self.arguments[0] if self.arguments else "") + node["classes"].extend(classes) + + self.state.nested_parse(self.content, self.content_offset, node) + return [node] + + +def setup(app: Sphinx) -> SphinxExtensionMeta: + app.add_node(collapse, html=(visit_collapse_node, depart_collapse_node)) + app.add_directive("collapse", CollapseDirective) + + return { + "parallel_read_safe": True, + "parallel_write_safe": True, + } From b23786bd7064e4e11e1657b1f70f6a83480ec19d Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 26 Oct 2023 19:22:52 +0200 Subject: [PATCH 5/6] fix(commands): handle interactions in union types correctly (#1121) --- changelog/1121.feature.rst | 1 + disnake/ext/commands/params.py | 38 ++++++++++++------------ tests/ext/commands/test_params.py | 49 ++++++++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 19 deletions(-) create mode 100644 changelog/1121.feature.rst diff --git a/changelog/1121.feature.rst b/changelog/1121.feature.rst new file mode 100644 index 0000000000..1294ba4044 --- /dev/null +++ b/changelog/1121.feature.rst @@ -0,0 +1 @@ +Make :class:`Interaction` and subtypes accept the bot type as a generic parameter to denote the type returned by the :attr:`~Interaction.bot` and :attr:`~Interaction.client` properties. diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 2ab93359d2..0e702385ad 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -31,7 +31,6 @@ Type, TypeVar, Union, - get_args, get_origin, get_type_hints, ) @@ -110,17 +109,26 @@ def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[TypeT]: + """Similar to the builtin `issubclass`, but more lenient. + Can also handle unions (`issubclass(Union[int, str], int)`) and + generic types (`issubclass(X[T], X)`) in the first argument. + """ if not isinstance(tp, (type, tuple)): return False - elif not isinstance(obj, type): - # Assume we have a type hint - if get_origin(obj) in (Union, UnionType, Optional): - obj = get_args(obj) - return any(isinstance(o, type) and issubclass(o, tp) for o in obj) - else: - # Other type hint specializations are not supported - return False - return issubclass(obj, tp) + elif isinstance(obj, type): + # common case + return issubclass(obj, tp) + + # At this point, `obj` is likely a generic type hint + if (origin := get_origin(obj)) is None: + return False + + if origin in (Union, UnionType): + # If we have a Union, try matching any of its args + # (recursively, to handle possibly generic types inside this union) + return any(issubclass_(o, tp) for o in obj.__args__) + else: + return isinstance(origin, type) and issubclass(origin, tp) def remove_optionals(annotation: Any) -> Any: @@ -912,7 +920,6 @@ def isolate_self( parametersl.pop(0) if parametersl: annot = parametersl[0].annotation - annot = get_origin(annot) or annot if issubclass_(annot, ApplicationCommandInteraction) or annot is inspect.Parameter.empty: inter_param = parameters.pop(parametersl[0].name) @@ -984,9 +991,7 @@ def collect_params( injections[parameter.name] = default elif parameter.annotation in Injection._registered: injections[parameter.name] = Injection._registered[parameter.annotation] - elif issubclass_( - get_origin(parameter.annotation) or parameter.annotation, ApplicationCommandInteraction - ): + elif issubclass_(parameter.annotation, ApplicationCommandInteraction): if inter_param is None: inter_param = parameter else: @@ -1120,10 +1125,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]: if param.autocomplete: command.autocompleters[param.name] = param.autocomplete - if issubclass_( - get_origin(annot := sig.parameters[inter_param].annotation) or annot, - disnake.GuildCommandInteraction, - ): + if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction): command._guild_only = True return [param.to_option() for param in params] diff --git a/tests/ext/commands/test_params.py b/tests/ext/commands/test_params.py index a3b4ea4289..8e8ca91304 100644 --- a/tests/ext/commands/test_params.py +++ b/tests/ext/commands/test_params.py @@ -10,6 +10,7 @@ import disnake from disnake import Member, Role, User from disnake.ext import commands +from disnake.ext.commands import params OptionType = disnake.OptionType @@ -66,6 +67,53 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None: with pytest.raises(commands.errors.MemberNotFound): await info.verify_type(mock.Mock(), arg_mock) + def test_isolate_self(self) -> None: + def func(a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_inter(self) -> None: + def func(i: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_cog_inter(self) -> None: + def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is not None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_generic(self) -> None: + def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_union(self) -> None: + def func( + i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int + ) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + # this uses `Range` for testing `_BaseRange`, `String` should work equally class TestBaseRange: @@ -189,7 +237,6 @@ def test_string(self) -> None: assert info.max_value is None assert info.type == annotation.underlying_type - # uses lambdas since new union syntax isn't supported on all versions @pytest.mark.parametrize( "annotation_str", [ From f2e5886c1d103f789fd0f8e44f6b1279e7e9ffa5 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 26 Oct 2023 19:42:05 +0200 Subject: [PATCH 6/6] refactor: unify slash/prefix command signature evaluation (#1116) --- changelog/1116.misc.rst | 1 + disnake/ext/commands/core.py | 49 ++++-------------- disnake/ext/commands/params.py | 83 +++++++++++-------------------- disnake/utils.py | 67 +++++++++++++++++++++++++ tests/ext/commands/test_params.py | 10 ++-- 5 files changed, 113 insertions(+), 97 deletions(-) create mode 100644 changelog/1116.misc.rst diff --git a/changelog/1116.misc.rst b/changelog/1116.misc.rst new file mode 100644 index 0000000000..7e17a486ef --- /dev/null +++ b/changelog/1116.misc.rst @@ -0,0 +1 @@ +|commands| Rewrite slash command signature evaluation to use the same mechanism as prefix command signatures. This should not have an impact on user code, but streamlines future changes. diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index 2d7ff5497e..669dd60e04 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -27,7 +27,12 @@ ) import disnake -from disnake.utils import _generated, _overload_with_permissions +from disnake.utils import ( + _generated, + _overload_with_permissions, + get_signature_parameters, + unwrap_function, +) from ._types import _BaseCommand from .cog import Cog @@ -114,42 +119,6 @@ P = TypeVar("P") -def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: - partial = functools.partial - while True: - if hasattr(function, "__wrapped__"): - function = function.__wrapped__ - elif isinstance(function, partial): - function = function.func - else: - return function - - -def get_signature_parameters( - function: Callable[..., Any], globalns: Dict[str, Any] -) -> Dict[str, inspect.Parameter]: - signature = inspect.signature(function) - params = {} - cache: Dict[str, Any] = {} - eval_annotation = disnake.utils.evaluate_annotation - for name, parameter in signature.parameters.items(): - annotation = parameter.annotation - if annotation is parameter.empty: - params[name] = parameter - continue - if annotation is None: - params[name] = parameter.replace(annotation=type(None)) - continue - - annotation = eval_annotation(annotation, globalns, globalns, cache) - if annotation is Greedy: - raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") - - params[name] = parameter.replace(annotation=annotation) - - return params - - def wrap_callback(coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -410,7 +379,11 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None: except AttributeError: globalns = {} - self.params = get_signature_parameters(function, globalns) + params = get_signature_parameters(function, globalns) + for param in params.values(): + if param.annotation is Greedy: + raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") + self.params = params def add_check(self, func: Check) -> None: """Adds a check to the command. diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 0e702385ad..5aae2de611 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -10,6 +10,7 @@ import itertools import math import sys +import types from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, EnumMeta @@ -32,7 +33,6 @@ TypeVar, Union, get_origin, - get_type_hints, ) import disnake @@ -42,7 +42,7 @@ from disnake.ext import commands from disnake.i18n import Localized from disnake.interactions import ApplicationCommandInteraction -from disnake.utils import maybe_coroutine +from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine from . import errors from .converter import CONVERTER_MAPPING @@ -143,37 +143,6 @@ def remove_optionals(annotation: Any) -> Any: return annotation -def signature(func: Callable) -> inspect.Signature: - """Get the signature with evaluated annotations wherever possible - - This is equivalent to `signature(..., eval_str=True)` in python 3.10 - """ - if sys.version_info >= (3, 10): - return inspect.signature(func, eval_str=True) - - if inspect.isfunction(func) or inspect.ismethod(func): - typehints = get_type_hints(func) - else: - typehints = get_type_hints(func.__call__) - - signature = inspect.signature(func) - parameters = [] - - for name, param in signature.parameters.items(): - if isinstance(param.annotation, str): - param = param.replace(annotation=typehints.get(name, inspect.Parameter.empty)) - if param.annotation is type(None): - param = param.replace(annotation=None) - - parameters.append(param) - - return_annotation = typehints.get("return", inspect.Parameter.empty) - if return_annotation is type(None): - return_annotation = None - - return signature.replace(parameters=parameters, return_annotation=return_annotation) - - def _xt_to_xe(xe: Optional[float], xt: Optional[float], direction: float = 1) -> Optional[float]: """Function for combining xt and xe @@ -795,7 +764,14 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo return True def parse_converter_annotation(self, converter: Callable, fallback_annotation: Any) -> None: - _, parameters = isolate_self(signature(converter)) + if isinstance(converter, (types.FunctionType, types.MethodType)): + converter_func = converter + else: + # if converter isn't a function/method, assume it's a callable object/type + # (we need `__call__` here to get the correct global namespace later, since + # classes do not have `__globals__`) + converter_func = converter.__call__ + _, parameters = isolate_self(get_signature_parameters(converter_func)) if len(parameters) != 1: raise TypeError( @@ -858,9 +834,9 @@ def to_option(self) -> Option: def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwargs: Any) -> T: """Calls a function without providing any extra unexpected arguments""" MISSING: Any = object() - sig = signature(function) + parameters = get_signature_parameters(function) - kinds = {p.kind for p in sig.parameters.values()} + kinds = {p.kind for p in parameters.values()} arb = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} if arb.issubset(kinds): raise TypeError( @@ -874,7 +850,7 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa for index, parameter, posarg in itertools.zip_longest( itertools.count(), - sig.parameters.values(), + parameters.values(), possible_args, fillvalue=MISSING, ): @@ -903,15 +879,15 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa def isolate_self( - sig: inspect.Signature, + parameters: Dict[str, inspect.Parameter], ) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]: """Create parameters without self and the first interaction""" - parameters = dict(sig.parameters) - parametersl = list(sig.parameters.values()) - if not parameters: return (None, None), {} + parameters = dict(parameters) # shallow copy + parametersl = list(parameters.values()) + cog_param: Optional[inspect.Parameter] = None inter_param: Optional[inspect.Parameter] = None @@ -961,19 +937,19 @@ def classify_autocompleter(autocompleter: AnyAutocompleter) -> None: def collect_params( function: Callable, - sig: Optional[inspect.Signature] = None, + parameters: Optional[Dict[str, inspect.Parameter]] = None, ) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]: """Collect all parameters in a function. - Optionally accepts an `inspect.Signature` object (as an optimization), - calls `signature(function)` if not provided. + Optionally accepts a `{str: inspect.Parameter}` dict as an optimization, + calls `get_signature_parameters(function)` if not provided. Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`) """ - if sig is None: - sig = signature(function) + if parameters is None: + parameters = get_signature_parameters(function) - (cog_param, inter_param), parameters = isolate_self(sig) + (cog_param, inter_param), parameters = isolate_self(parameters) doc = disnake.utils.parse_docstring(function)["params"] @@ -1097,10 +1073,10 @@ def expand_params(command: AnySlashCommand) -> List[Option]: Returns the created options """ - sig = signature(command.callback) - # pass `sig` down to avoid having to call `signature(func)` another time, + parameters = get_signature_parameters(command.callback) + # pass `parameters` down to avoid having to call `get_signature_parameters(func)` another time, # which may cause side effects with deferred annotations and warnings - _, inter_param, params, injections = collect_params(command.callback, sig) + _, inter_param, params, injections = collect_params(command.callback, parameters) if inter_param is None: raise TypeError(f"Couldn't find an interaction parameter in {command.callback}") @@ -1125,7 +1101,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]: if param.autocomplete: command.autocompleters[param.name] = param.autocomplete - if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction): + if issubclass_(parameters[inter_param].annotation, disnake.GuildCommandInteraction): command._guild_only = True return [param.to_option() for param in params] @@ -1407,12 +1383,11 @@ def register_injection( :class:`Injection` The injection being registered. """ - sig = signature(function) - tp = sig.return_annotation + tp = get_signature_return(function) if tp is inspect.Parameter.empty: raise TypeError("Injection must have a return annotation") if tp in ParamInfo.TYPES: raise TypeError("Injection cannot overwrite builtin types") - return Injection.register(function, sig.return_annotation, autocompleters=autocompleters) + return Injection.register(function, tp, autocompleters=autocompleters) diff --git a/disnake/utils.py b/disnake/utils.py index 1d06f137d2..d40cd4e8fe 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -6,6 +6,7 @@ import asyncio import datetime import functools +import inspect import json import os import pkgutil @@ -1203,6 +1204,72 @@ def resolve_annotation( return evaluate_annotation(annotation, globalns, locals, cache) +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, "__wrapped__"): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + + +def _get_function_globals(function: Callable[..., Any]) -> Dict[str, Any]: + unwrap = unwrap_function(function) + try: + return unwrap.__globals__ + except AttributeError: + return {} + + +_inspect_empty = inspect.Parameter.empty + + +def get_signature_parameters( + function: Callable[..., Any], globalns: Optional[Dict[str, Any]] = None +) -> Dict[str, inspect.Parameter]: + # if no globalns provided, unwrap (where needed) and get global namespace from there + if globalns is None: + globalns = _get_function_globals(function) + + params: Dict[str, inspect.Parameter] = {} + cache: Dict[str, Any] = {} + + signature = inspect.signature(function) + + # eval all parameter annotations + for name, parameter in signature.parameters.items(): + annotation = parameter.annotation + if annotation is _inspect_empty: + params[name] = parameter + continue + + if annotation is None: + annotation = type(None) + else: + annotation = evaluate_annotation(annotation, globalns, globalns, cache) + + params[name] = parameter.replace(annotation=annotation) + + return params + + +def get_signature_return(function: Callable[..., Any]) -> Any: + signature = inspect.signature(function) + + # same as parameters above, but for the return annotation + ret = signature.return_annotation + if ret is not _inspect_empty: + if ret is None: + ret = type(None) + else: + globalns = _get_function_globals(function) + ret = evaluate_annotation(ret, globalns, globalns, {}) + + return ret + + TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] diff --git a/tests/ext/commands/test_params.py b/tests/ext/commands/test_params.py index 8e8ca91304..61c812c8f0 100644 --- a/tests/ext/commands/test_params.py +++ b/tests/ext/commands/test_params.py @@ -71,7 +71,7 @@ def test_isolate_self(self) -> None: def func(a: int) -> None: ... - (cog, inter), parameters = params.isolate_self(params.signature(func)) + (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) assert cog is None assert inter is None assert parameters == ({"a": mock.ANY}) @@ -80,7 +80,7 @@ def test_isolate_self_inter(self) -> None: def func(i: disnake.ApplicationCommandInteraction, a: int) -> None: ... - (cog, inter), parameters = params.isolate_self(params.signature(func)) + (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) assert cog is None assert inter is not None assert parameters == ({"a": mock.ANY}) @@ -89,7 +89,7 @@ def test_isolate_self_cog_inter(self) -> None: def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None: ... - (cog, inter), parameters = params.isolate_self(params.signature(func)) + (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) assert cog is not None assert inter is not None assert parameters == ({"a": mock.ANY}) @@ -98,7 +98,7 @@ def test_isolate_self_generic(self) -> None: def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: ... - (cog, inter), parameters = params.isolate_self(params.signature(func)) + (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) assert cog is None assert inter is not None assert parameters == ({"a": mock.ANY}) @@ -109,7 +109,7 @@ def func( ) -> None: ... - (cog, inter), parameters = params.isolate_self(params.signature(func)) + (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) assert cog is None assert inter is not None assert parameters == ({"a": mock.ANY})