From aa97387f6f78494c50933c70a453a8b8ac8dc01a Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 9 Oct 2023 15:57:40 +0200 Subject: [PATCH 01/11] feat: handle `TypeAliasType` in annotations --- disnake/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/disnake/utils.py b/disnake/utils.py index d40cd4e8fe..aed8e2fa84 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1144,6 +1144,13 @@ def evaluate_annotation( cache[tp] = evaluated return evaluated + # TypeAliasType, 3.12+ + if hasattr(tp, "__value__"): + # accessing `__value__` automatically evaluates the type alias in the annotation scope; + # recurse to resolve possible forwardrefs + return evaluate_annotation(tp.__value__, globals, locals, cache) + + # GenericAlias if hasattr(tp, "__args__"): implicit_str = True is_literal = False From dc5a5fb71a49efccefefeef8ba8ee642b4d361ec Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 9 Oct 2023 15:58:04 +0200 Subject: [PATCH 02/11] test: add test for new functionality --- tests/test_utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index a8f52e6b1f..fcbd31eb5a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ import warnings from dataclasses import dataclass from datetime import timedelta, timezone -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from unittest import mock import pytest @@ -19,6 +19,12 @@ from . import helpers +if TYPE_CHECKING: + from typing_extensions import TypeAliasType +elif sys.version_info >= (3, 12): + # non-3.12 tests shouldn't be using this + from typing import TypeAliasType + def test_missing() -> None: assert utils.MISSING != utils.MISSING @@ -784,6 +790,13 @@ def test_resolve_annotation_literal() -> None: utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {}) # type: ignore +@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12") +def test_resolve_annotation_typealiastype() -> None: + # this is equivalent to `type CoolList = List['int']` + CoolList = TypeAliasType("CoolList", List["int"]) + assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int] + + @pytest.mark.parametrize( ("dt", "style", "expected"), [ From cd57c65dd29be0921ea6e101af641d8f83840b7d Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 13 Oct 2023 15:39:47 +0200 Subject: [PATCH 03/11] fix: use correct namespace for recursing on TypeAliasTypes --- disnake/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index aed8e2fa84..bb138eb394 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1146,8 +1146,11 @@ def evaluate_annotation( # TypeAliasType, 3.12+ if hasattr(tp, "__value__"): - # accessing `__value__` automatically evaluates the type alias in the annotation scope; - # recurse to resolve possible forwardrefs + # Use __module__ to get the namespace in which the type alias was defined. + if mod := sys.modules.get(tp.__module__): + globals = locals = mod.__dict__ + # Accessing `__value__` automatically evaluates the type alias in the annotation scope. + # (recurse to resolve possible forwardrefs) return evaluate_annotation(tp.__value__, globals, locals, cache) # GenericAlias From 8c5f3b8bcbf50f000cbcb15dae66b61f73ea2973 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sat, 14 Oct 2023 22:10:03 +0200 Subject: [PATCH 04/11] fix: check `GenericAlias` before `TypeAliasType` Reason for this is that GenericAlias proxies its `__origin__`'s attributes; if we have a GenericAlias with a TypeAliasType `__origin__`, `hasattr(tp, "__value__")` will match even though we're really looking at a GenericAlias. --- disnake/utils.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index bb138eb394..efe11e2b28 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1119,6 +1119,17 @@ def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: return tuple(p for p in parameters if p is not none_cls) + (none_cls,) +def _resolve_typealiastype( + tp: Any, globals: Dict[str, Any], locals: Dict[str, Any], cache: Dict[str, Any] +): + # Use __module__ to get the namespace in which the type alias was defined. + if mod := sys.modules.get(tp.__module__): + globals = locals = mod.__dict__ + # Accessing `__value__` automatically evaluates the type alias in the annotation scope. + # (recurse to resolve possible forwardrefs, aliases, etc.) + return evaluate_annotation(tp.__value__, globals, locals, cache) + + def evaluate_annotation( tp: Any, globals: Dict[str, Any], @@ -1144,16 +1155,7 @@ def evaluate_annotation( cache[tp] = evaluated return evaluated - # TypeAliasType, 3.12+ - if hasattr(tp, "__value__"): - # Use __module__ to get the namespace in which the type alias was defined. - if mod := sys.modules.get(tp.__module__): - globals = locals = mod.__dict__ - # Accessing `__value__` automatically evaluates the type alias in the annotation scope. - # (recurse to resolve possible forwardrefs) - return evaluate_annotation(tp.__value__, globals, locals, cache) - - # GenericAlias + # GenericAlias / UnionType if hasattr(tp, "__args__"): implicit_str = True is_literal = False @@ -1194,6 +1196,10 @@ def evaluate_annotation( except AttributeError: return tp.__origin__[evaluated_args] + # TypeAliasType, 3.12+ + if hasattr(tp, "__value__"): + return _resolve_typealiastype(tp, globals, locals, cache) + return tp From 1e4e701f60d4066f8c7c785de8c0f33bf1aafe08 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sat, 14 Oct 2023 22:14:55 +0200 Subject: [PATCH 05/11] feat: handle TypeAliasType origins --- disnake/utils.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index efe11e2b28..0550ed23ec 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1157,22 +1157,29 @@ def evaluate_annotation( # GenericAlias / UnionType if hasattr(tp, "__args__"): - implicit_str = True - is_literal = False - orig_args = args = tp.__args__ if not hasattr(tp, "__origin__"): if tp.__class__ is UnionType: converted = Union[args] # type: ignore return evaluate_annotation(converted, globals, locals, cache) return tp - if tp.__origin__ is Union: + + implicit_str = True + is_literal = False + orig_args = args = tp.__args__ + orig_origin = origin = tp.__origin__ + + # origin can be a TypeAliasType too, resolve it and continue + if hasattr(origin, "__value__"): + origin = _resolve_typealiastype(origin, globals, locals, cache) + + if origin is Union: try: if args.index(type(None)) != len(args) - 1: args = normalise_optional_params(tp.__args__) except ValueError: pass - if tp.__origin__ is Literal: + if origin is Literal: if not PY_310: args = flatten_literal_params(tp.__args__) implicit_str = False @@ -1188,13 +1195,17 @@ def evaluate_annotation( ): raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + if origin != orig_origin: + # we can't use `copy_with` in this case, so just skip all of the following logic + return origin[evaluated_args] + if evaluated_args == orig_args: return tp try: return tp.copy_with(evaluated_args) except AttributeError: - return tp.__origin__[evaluated_args] + return origin[evaluated_args] # TypeAliasType, 3.12+ if hasattr(tp, "__value__"): From ac7ebafe04484f381ab4b889f38d7ecefd82a820 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sat, 14 Oct 2023 22:15:17 +0200 Subject: [PATCH 06/11] fix: please send help --- disnake/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disnake/utils.py b/disnake/utils.py index 0550ed23ec..37ab722890 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1159,7 +1159,7 @@ def evaluate_annotation( if hasattr(tp, "__args__"): if not hasattr(tp, "__origin__"): if tp.__class__ is UnionType: - converted = Union[args] # type: ignore + converted = Union[tp.__args__] # type: ignore return evaluate_annotation(converted, globals, locals, cache) return tp From 233c954e85cd252f17565121b0bb7d5ffd9db78c Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sat, 14 Oct 2023 22:15:50 +0200 Subject: [PATCH 07/11] test: more complicated tests --- tests/test_utils.py | 44 +++++++++++++++++++++++++++++++----- tests/utils_helper_module.py | 23 +++++++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) create mode 100644 tests/utils_helper_module.py diff --git a/tests/test_utils.py b/tests/test_utils.py index fcbd31eb5a..a96f18ed54 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ import warnings from dataclasses import dataclass from datetime import timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union from unittest import mock import pytest @@ -17,7 +17,7 @@ import disnake from disnake import utils -from . import helpers +from . import helpers, utils_helper_module if TYPE_CHECKING: from typing_extensions import TypeAliasType @@ -791,10 +791,42 @@ def test_resolve_annotation_literal() -> None: @pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12") -def test_resolve_annotation_typealiastype() -> None: - # this is equivalent to `type CoolList = List['int']` - CoolList = TypeAliasType("CoolList", List["int"]) - assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int] +class TestResolveAnnotationTypeAliasType: + def test_simple(self) -> None: + # this is equivalent to `type CoolList = List[int]` + CoolList = TypeAliasType("CoolList", List[int]) + assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int] + + def test_generic(self) -> None: + # this is equivalent to `type CoolList[T] = List[T]; CoolList[int]` + T = TypeVar("T") + CoolList = TypeAliasType("CoolList", List[T], type_params=(T,)) + + annotation = CoolList[int] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[int] + + # alias and arg in local scope + def test_forwardref_local(self) -> None: + T = TypeVar("T") + IntOrStr = Union[int, str] + CoolList = TypeAliasType("CoolList", List[T], type_params=(T,)) + + annotation = CoolList["IntOrStr"] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[IntOrStr] + + # alias and arg in other module scope + def test_forwardref_module(self) -> None: + resolved = utils.resolve_annotation( + utils_helper_module.ListWithForwardRefAlias, globals(), locals(), {} + ) + assert resolved == List[Union[int, str]] + + # combination of the previous two, alias in other module scope and arg in local scope + def test_forwardref_mixed(self) -> None: + LocalIntOrStr = Union[int, str] + + annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr] @pytest.mark.parametrize( diff --git a/tests/utils_helper_module.py b/tests/utils_helper_module.py new file mode 100644 index 0000000000..074e844b32 --- /dev/null +++ b/tests/utils_helper_module.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: MIT + +"""Separate module file for some test_utils.py type annotation tests.""" + +import sys +from typing import TYPE_CHECKING, List, TypeVar, Union + +version = sys.version_info # assign to variable to trick pyright + +if TYPE_CHECKING: + from typing_extensions import TypeAliasType +elif version >= (3, 12): + # non-3.12 tests shouldn't be using this + from typing import TypeAliasType + +if version >= (3, 12): + CoolUniqueIntOrStrAlias = Union[int, str] + ListWithForwardRefAlias = TypeAliasType( + "ListWithForwardRefAlias", List["CoolUniqueIntOrStrAlias"] + ) + + T = TypeVar("T") + GenericListAlias = TypeAliasType("GenericListAlias", List[T], type_params=(T,)) From 131adbf14de20173017dc1c85e94239946bc78d8 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 26 Oct 2023 19:52:35 +0200 Subject: [PATCH 08/11] docs: add changelog entry --- changelog/1128.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1128.feature.rst diff --git a/changelog/1128.feature.rst b/changelog/1128.feature.rst new file mode 100644 index 0000000000..66c35b1935 --- /dev/null +++ b/changelog/1128.feature.rst @@ -0,0 +1 @@ +|commands| Support Python 3.12's ``type`` statement and :class:`py:typing.TypeAliasType` annotations in command signatures. From d9a36c8ae321b59d20d6622b8efa2825f2621be2 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 26 Oct 2023 21:18:51 +0200 Subject: [PATCH 09/11] test: add regression test for dropping cache on namespace change --- tests/test_utils.py | 20 ++++++++++++++++++++ tests/utils_helper_module.py | 3 +++ 2 files changed, 23 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index a96f18ed54..7a590533b8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -828,6 +828,26 @@ def test_forwardref_mixed(self) -> None: annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"] assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr] + # two different forwardrefs with same name + def test_forwardref_duplicate(self) -> None: + DuplicateAlias = int + + # first, resolve an annotation where `DuplicateAlias` resolves to the local int + cache = {} + assert ( + utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache) + == List[int] + ) + + # then, resolve an annotation where the globalns changes and `DuplicateAlias` resolves to something else + # (i.e. this should not resolve to `List[int]` despite {"DuplicateAlias": int} in the cache) + assert ( + utils.resolve_annotation( + utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache + ) + == List[str] + ) + @pytest.mark.parametrize( ("dt", "style", "expected"), diff --git a/tests/utils_helper_module.py b/tests/utils_helper_module.py index 074e844b32..7711e861b8 100644 --- a/tests/utils_helper_module.py +++ b/tests/utils_helper_module.py @@ -21,3 +21,6 @@ T = TypeVar("T") GenericListAlias = TypeAliasType("GenericListAlias", List[T], type_params=(T,)) + + DuplicateAlias = str + ListWithDuplicateAlias = TypeAliasType("ListWithDuplicateAlias", List["DuplicateAlias"]) From 440c6a48538ebb6559cd740e4bce46721ba3fa1e Mon Sep 17 00:00:00 2001 From: shiftinv Date: Thu, 26 Oct 2023 21:19:07 +0200 Subject: [PATCH 10/11] fix: drop annotation cache when moving to different globalns while resolving TypeAliasType see previous commit for tests --- disnake/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index 37ab722890..56da106ed2 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1122,9 +1122,15 @@ def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: def _resolve_typealiastype( tp: Any, globals: Dict[str, Any], locals: Dict[str, Any], cache: Dict[str, Any] ): - # Use __module__ to get the namespace in which the type alias was defined. + # Use __module__ to get the (global) namespace in which the type alias was defined. if mod := sys.modules.get(tp.__module__): - globals = locals = mod.__dict__ + mod_globals = mod.__dict__ + if mod_globals is not globals or mod_globals is not locals: + # if the namespace changed (usually when a TypeAliasType was imported from a different module), + # drop the cache since names can resolve differently now + cache = {} + globals = locals = mod_globals + # Accessing `__value__` automatically evaluates the type alias in the annotation scope. # (recurse to resolve possible forwardrefs, aliases, etc.) return evaluate_annotation(tp.__value__, globals, locals, cache) From e64ead51073b39eaf5b9ee6620f7ea1b7305f94b Mon Sep 17 00:00:00 2001 From: shiftinv Date: Wed, 22 Nov 2023 23:58:32 +0100 Subject: [PATCH 11/11] chore: add fixme for future reference --- disnake/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/disnake/utils.py b/disnake/utils.py index 2bb9ea43eb..9061cd0f61 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1139,6 +1139,7 @@ def _resolve_typealiastype( return evaluate_annotation(tp.__value__, globals, locals, cache) +# FIXME: this should be split up into smaller functions for clarity and easier maintenance def evaluate_annotation( tp: Any, globals: Dict[str, Any],