Skip to content

Commit

Permalink
feat: add support for union (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
faph authored Oct 30, 2023
2 parents b50bcdb + db82cc6 commit ac2521b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ Is output as:
}
:class:`typing.Union`
~~~~~~~~~~~~~~~~~~~~~
:class:`typing.Union` and :class:`types.UnionType` (``X | Y``)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Avro schema: JSON array of multiple Avro schemas

Expand Down
6 changes: 6 additions & 0 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import enum
import inspect
import sys
import types
import uuid
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -572,6 +573,11 @@ class UnionSchema(Schema):
def handles_type(cls, py_type: Type) -> bool:
"""Whether this schema class can represent a given Python class"""
origin = get_origin(py_type)

# Support for `X | Y` syntax available in Python 3.10+
# equivalent to `typing.Union[X, Y]`
if getattr(types, "UnionType", None):
return origin == Union or origin == types.UnionType # noqa: E721
return origin == Union

def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, options: Option = Option(0)):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import enum
import re
import sys
from typing import (
Dict,
List,
Expand Down Expand Up @@ -137,6 +138,10 @@ def test_string_tuple():
expected = {"type": "array", "items": "string"}
assert_schema(py_type, expected)

py_type = tuple[str]
expected = {"type": "array", "items": "string"}
assert_schema(py_type, expected)


def test_string_sequence():
py_type = Sequence[str]
Expand Down Expand Up @@ -220,19 +225,43 @@ def test_string_dict_of_dicts():
}
assert_schema(py_type, expected)

py_type = dict[str, dict[str, str]]
expected = {
"type": "map",
"values": {
"type": "map",
"values": "string",
},
}
assert_schema(py_type, expected)


def test_union_string_int():
py_type = Union[str, int]
expected = ["string", "long"]
assert_schema(py_type, expected)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
def test_union_string_int_py310():
py_type = str | int
expected = ["string", "long"]
assert_schema(py_type, expected)


def test_union_string_string_int():
py_type = Union[str, str, int]
expected = ["string", "long"]
assert_schema(py_type, expected)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
def test_union_string_string_int_py310():
py_type = str | int | str
expected = ["string", "long"]
assert_schema(py_type, expected)


def test_union_of_union_string_int():
py_type = Union[str, Union[str, int]]
expected = ["string", "long"]
Expand All @@ -245,6 +274,13 @@ def test_optional_str():
assert_schema(py_type, expected)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
def test_optional_str_py310():
py_type = str | None
expected = ["string", "null"]
assert_schema(py_type, expected)


def test_enum():
class PyType(enum.Enum):
RED = "RED"
Expand Down

0 comments on commit ac2521b

Please sign in to comment.