diff --git a/docs/types.rst b/docs/types.rst index ea4cb14..ca2bfb8 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -186,8 +186,8 @@ Is output as: } -:class:`typing.Union` -~~~~~~~~~~~~~~~~~~~~~ +:class:`typing.Union` and :class:`types.UnionType` (``X | Y``) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Avro schema: JSON array of multiple Avro schemas diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 2759c20..41d5663 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -24,6 +24,7 @@ import enum import inspect import sys +import types import uuid from typing import ( TYPE_CHECKING, @@ -572,6 +573,11 @@ class UnionSchema(Schema): def handles_type(cls, py_type: Type) -> bool: """Whether this schema class can represent a given Python class""" origin = get_origin(py_type) + + # Support for `X | Y` syntax available in Python 3.10+ + # equivalent to `typing.Union[X, Y]` + if getattr(types, "UnionType", None): + return origin == Union or origin == types.UnionType # noqa: E721 return origin == Union def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, options: Option = Option(0)): diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 2d2ee6c..191d742 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -11,6 +11,7 @@ import enum import re +import sys from typing import ( Dict, List, @@ -137,6 +138,10 @@ def test_string_tuple(): expected = {"type": "array", "items": "string"} assert_schema(py_type, expected) + py_type = tuple[str] + expected = {"type": "array", "items": "string"} + assert_schema(py_type, expected) + def test_string_sequence(): py_type = Sequence[str] @@ -220,6 +225,16 @@ def test_string_dict_of_dicts(): } assert_schema(py_type, expected) + py_type = dict[str, dict[str, str]] + expected = { + "type": "map", + "values": { + "type": "map", + "values": "string", + }, + } + assert_schema(py_type, expected) + def test_union_string_int(): py_type = Union[str, int] @@ -227,12 +242,26 @@ def test_union_string_int(): assert_schema(py_type, expected) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+") +def test_union_string_int_py310(): + py_type = str | int + expected = ["string", "long"] + assert_schema(py_type, expected) + + def test_union_string_string_int(): py_type = Union[str, str, int] expected = ["string", "long"] assert_schema(py_type, expected) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+") +def test_union_string_string_int_py310(): + py_type = str | int | str + expected = ["string", "long"] + assert_schema(py_type, expected) + + def test_union_of_union_string_int(): py_type = Union[str, Union[str, int]] expected = ["string", "long"] @@ -245,6 +274,13 @@ def test_optional_str(): assert_schema(py_type, expected) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+") +def test_optional_str_py310(): + py_type = str | None + expected = ["string", "null"] + assert_schema(py_type, expected) + + def test_enum(): class PyType(enum.Enum): RED = "RED"