Skip to content

Commit

Permalink
feat(commands): support 3.12's type statement and TypeAliasType (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv authored Dec 8, 2023
1 parent 594c12a commit 85cf393
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 9 deletions.
1 change: 1 addition & 0 deletions changelog/1128.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Support Python 3.12's ``type`` statement and :class:`py:typing.TypeAliasType` annotations in command signatures.
48 changes: 41 additions & 7 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,24 @@ def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)


def _resolve_typealiastype(
tp: Any, globals: Dict[str, Any], locals: Dict[str, Any], cache: Dict[str, Any]
):
# Use __module__ to get the (global) namespace in which the type alias was defined.
if mod := sys.modules.get(tp.__module__):
mod_globals = mod.__dict__
if mod_globals is not globals or mod_globals is not locals:
# if the namespace changed (usually when a TypeAliasType was imported from a different module),
# drop the cache since names can resolve differently now
cache = {}
globals = locals = mod_globals

# Accessing `__value__` automatically evaluates the type alias in the annotation scope.
# (recurse to resolve possible forwardrefs, aliases, etc.)
return evaluate_annotation(tp.__value__, globals, locals, cache)


# FIXME: this should be split up into smaller functions for clarity and easier maintenance
def evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
Expand All @@ -1147,23 +1165,31 @@ def evaluate_annotation(
cache[tp] = evaluated
return evaluated

# GenericAlias / UnionType
if hasattr(tp, "__args__"):
implicit_str = True
is_literal = False
orig_args = args = tp.__args__
if not hasattr(tp, "__origin__"):
if tp.__class__ is UnionType:
converted = Union[args] # type: ignore
converted = Union[tp.__args__] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)

return tp
if tp.__origin__ is Union:

implicit_str = True
is_literal = False
orig_args = args = tp.__args__
orig_origin = origin = tp.__origin__

# origin can be a TypeAliasType too, resolve it and continue
if hasattr(origin, "__value__"):
origin = _resolve_typealiastype(origin, globals, locals, cache)

if origin is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if origin is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
Expand All @@ -1179,13 +1205,21 @@ def evaluate_annotation(
):
raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")

if origin != orig_origin:
# we can't use `copy_with` in this case, so just skip all of the following logic
return origin[evaluated_args]

if evaluated_args == orig_args:
return tp

try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return origin[evaluated_args]

# TypeAliasType, 3.12+
if hasattr(tp, "__value__"):
return _resolve_typealiastype(tp, globals, locals, cache)

return tp

Expand Down
69 changes: 67 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from dataclasses import dataclass
from datetime import timedelta, timezone
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union
from unittest import mock

import pytest
Expand All @@ -18,7 +18,13 @@
import disnake
from disnake import utils

from . import helpers
from . import helpers, utils_helper_module

if TYPE_CHECKING:
from typing_extensions import TypeAliasType
elif sys.version_info >= (3, 12):
# non-3.12 tests shouldn't be using this
from typing import TypeAliasType


def test_missing() -> None:
Expand Down Expand Up @@ -785,6 +791,65 @@ def test_resolve_annotation_literal() -> None:
utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {}) # type: ignore


@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12")
class TestResolveAnnotationTypeAliasType:
def test_simple(self) -> None:
# this is equivalent to `type CoolList = List[int]`
CoolList = TypeAliasType("CoolList", List[int])
assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int]

def test_generic(self) -> None:
# this is equivalent to `type CoolList[T] = List[T]; CoolList[int]`
T = TypeVar("T")
CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))

annotation = CoolList[int]
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[int]

# alias and arg in local scope
def test_forwardref_local(self) -> None:
T = TypeVar("T")
IntOrStr = Union[int, str]
CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))

annotation = CoolList["IntOrStr"]
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[IntOrStr]

# alias and arg in other module scope
def test_forwardref_module(self) -> None:
resolved = utils.resolve_annotation(
utils_helper_module.ListWithForwardRefAlias, globals(), locals(), {}
)
assert resolved == List[Union[int, str]]

# combination of the previous two, alias in other module scope and arg in local scope
def test_forwardref_mixed(self) -> None:
LocalIntOrStr = Union[int, str]

annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"]
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr]

# two different forwardrefs with same name
def test_forwardref_duplicate(self) -> None:
DuplicateAlias = int

# first, resolve an annotation where `DuplicateAlias` resolves to the local int
cache = {}
assert (
utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache)
== List[int]
)

# then, resolve an annotation where the globalns changes and `DuplicateAlias` resolves to something else
# (i.e. this should not resolve to `List[int]` despite {"DuplicateAlias": int} in the cache)
assert (
utils.resolve_annotation(
utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache
)
== List[str]
)


@pytest.mark.parametrize(
("dt", "style", "expected"),
[
Expand Down
26 changes: 26 additions & 0 deletions tests/utils_helper_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: MIT

"""Separate module file for some test_utils.py type annotation tests."""

import sys
from typing import TYPE_CHECKING, List, TypeVar, Union

version = sys.version_info # assign to variable to trick pyright

if TYPE_CHECKING:
from typing_extensions import TypeAliasType
elif version >= (3, 12):
# non-3.12 tests shouldn't be using this
from typing import TypeAliasType

if version >= (3, 12):
CoolUniqueIntOrStrAlias = Union[int, str]
ListWithForwardRefAlias = TypeAliasType(
"ListWithForwardRefAlias", List["CoolUniqueIntOrStrAlias"]
)

T = TypeVar("T")
GenericListAlias = TypeAliasType("GenericListAlias", List[T], type_params=(T,))

DuplicateAlias = str
ListWithDuplicateAlias = TypeAliasType("ListWithDuplicateAlias", List["DuplicateAlias"])

0 comments on commit 85cf393

Please sign in to comment.