Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(commands): support 3.12's type statement and TypeAliasType #1128

Merged
merged 13 commits into from
Dec 8, 2023
Merged
1 change: 1 addition & 0 deletions changelog/1128.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Support Python 3.12's ``type`` statement and :class:`py:typing.TypeAliasType` annotations in command signatures.
48 changes: 41 additions & 7 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,24 @@ def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)


def _resolve_typealiastype(
tp: Any, globals: Dict[str, Any], locals: Dict[str, Any], cache: Dict[str, Any]
):
# Use __module__ to get the (global) namespace in which the type alias was defined.
if mod := sys.modules.get(tp.__module__):
mod_globals = mod.__dict__
if mod_globals is not globals or mod_globals is not locals:
# if the namespace changed (usually when a TypeAliasType was imported from a different module),
# drop the cache since names can resolve differently now
cache = {}
globals = locals = mod_globals

# Accessing `__value__` automatically evaluates the type alias in the annotation scope.
# (recurse to resolve possible forwardrefs, aliases, etc.)
return evaluate_annotation(tp.__value__, globals, locals, cache)


# FIXME: this should be split up into smaller functions for clarity and easier maintenance
def evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
Expand All @@ -1147,23 +1165,31 @@ def evaluate_annotation(
cache[tp] = evaluated
return evaluated

# GenericAlias / UnionType
if hasattr(tp, "__args__"):
implicit_str = True
is_literal = False
orig_args = args = tp.__args__
if not hasattr(tp, "__origin__"):
if tp.__class__ is UnionType:
converted = Union[args] # type: ignore
converted = Union[tp.__args__] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)

return tp
if tp.__origin__ is Union:

implicit_str = True
is_literal = False
orig_args = args = tp.__args__
orig_origin = origin = tp.__origin__

# origin can be a TypeAliasType too, resolve it and continue
if hasattr(origin, "__value__"):
origin = _resolve_typealiastype(origin, globals, locals, cache)

if origin is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if origin is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
Expand All @@ -1179,13 +1205,21 @@ def evaluate_annotation(
):
raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")

if origin != orig_origin:
# we can't use `copy_with` in this case, so just skip all of the following logic
return origin[evaluated_args]

if evaluated_args == orig_args:
return tp

try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return origin[evaluated_args]

# TypeAliasType, 3.12+
if hasattr(tp, "__value__"):
return _resolve_typealiastype(tp, globals, locals, cache)

return tp

Expand Down
69 changes: 67 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from dataclasses import dataclass
from datetime import timedelta, timezone
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union
from unittest import mock

import pytest
Expand All @@ -18,7 +18,13 @@
import disnake
from disnake import utils

from . import helpers
from . import helpers, utils_helper_module

if TYPE_CHECKING:
from typing_extensions import TypeAliasType
elif sys.version_info >= (3, 12):
# non-3.12 tests shouldn't be using this
from typing import TypeAliasType


def test_missing() -> None:
Expand Down Expand Up @@ -785,6 +791,65 @@ def test_resolve_annotation_literal() -> None:
utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {}) # type: ignore


@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12")
class TestResolveAnnotationTypeAliasType:
def test_simple(self) -> None:
# this is equivalent to `type CoolList = List[int]`
CoolList = TypeAliasType("CoolList", List[int])
assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int]

def test_generic(self) -> None:
# this is equivalent to `type CoolList[T] = List[T]; CoolList[int]`
T = TypeVar("T")
CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))

annotation = CoolList[int]
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[int]

# alias and arg in local scope
def test_forwardref_local(self) -> None:
T = TypeVar("T")
IntOrStr = Union[int, str]
CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))

annotation = CoolList["IntOrStr"]
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[IntOrStr]

# alias and arg in other module scope
def test_forwardref_module(self) -> None:
resolved = utils.resolve_annotation(
utils_helper_module.ListWithForwardRefAlias, globals(), locals(), {}
)
assert resolved == List[Union[int, str]]

# combination of the previous two, alias in other module scope and arg in local scope
def test_forwardref_mixed(self) -> None:
LocalIntOrStr = Union[int, str]

annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"]
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr]

# two different forwardrefs with same name
def test_forwardref_duplicate(self) -> None:
DuplicateAlias = int

# first, resolve an annotation where `DuplicateAlias` resolves to the local int
cache = {}
assert (
utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache)
== List[int]
)

# then, resolve an annotation where the globalns changes and `DuplicateAlias` resolves to something else
# (i.e. this should not resolve to `List[int]` despite {"DuplicateAlias": int} in the cache)
assert (
utils.resolve_annotation(
utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache
)
== List[str]
)


@pytest.mark.parametrize(
("dt", "style", "expected"),
[
Expand Down
26 changes: 26 additions & 0 deletions tests/utils_helper_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: MIT

"""Separate module file for some test_utils.py type annotation tests."""

import sys
from typing import TYPE_CHECKING, List, TypeVar, Union

version = sys.version_info # assign to variable to trick pyright

if TYPE_CHECKING:
from typing_extensions import TypeAliasType
elif version >= (3, 12):
# non-3.12 tests shouldn't be using this
from typing import TypeAliasType

if version >= (3, 12):
CoolUniqueIntOrStrAlias = Union[int, str]
ListWithForwardRefAlias = TypeAliasType(
"ListWithForwardRefAlias", List["CoolUniqueIntOrStrAlias"]
)

T = TypeVar("T")
GenericListAlias = TypeAliasType("GenericListAlias", List[T], type_params=(T,))

DuplicateAlias = str
ListWithDuplicateAlias = TypeAliasType("ListWithDuplicateAlias", List["DuplicateAlias"])