Skip to content

Commit

Permalink
Refactor annotation injection for known (often generic) types (pydant…
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Jul 29, 2024
1 parent 2d37b66 commit ee3e3b1
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 41 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ test-pydantic-extra-types: .pdm
git clone https://github.com/pydantic/pydantic-extra-types.git --single-branch
bash ./tests/test_pydantic_extra_types.sh

.PHONY: test-no-docs ## Run all tests except the docs tests
test-no-docs: .pdm
pdm run pytest tests --ignore=tests/test_docs.py

.PHONY: all ## Run the standard set of checks performed in CI
all: lint typecheck codespell testcov

Expand Down
101 changes: 72 additions & 29 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import dataclasses
import datetime
import inspect
import os
import pathlib
import re
import sys
import typing
Expand Down Expand Up @@ -39,6 +41,7 @@
from uuid import UUID
from warnings import warn

import typing_extensions
from pydantic_core import (
CoreSchema,
MultiHostUrl,
Expand Down Expand Up @@ -122,6 +125,28 @@
# Types that `match_type` / `_match_generic_type` route to `_frozenset_schema`.
FROZEN_SET_TYPES: list[type] = [frozenset, typing.FrozenSet, collections.abc.Set]
# Types that `match_type` / `_match_generic_type` route to `_dict_schema`.
DICT_TYPES: list[type] = [dict, typing.Dict, collections.abc.MutableMapping, collections.abc.Mapping]
# IP address, interface and network classes (IPv4 and IPv6 variants).
IP_TYPES: list[type] = [IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network]
# Types that `match_type` / `_match_generic_type` route to `_sequence_schema`.
SEQUENCE_TYPES: list[type] = [typing.Sequence, collections.abc.Sequence]
# Path-like types: `_get_prepare_pydantic_annotations_for_known_type` hands these to
# `path_schema_prepare_pydantic_annotations`.
PATH_TYPES: list[type] = [
os.PathLike,
pathlib.Path,
pathlib.PurePath,
pathlib.PosixPath,
pathlib.PurePosixPath,
pathlib.PureWindowsPath,
]
# Mapping-like types: `_get_prepare_pydantic_annotations_for_known_type` hands these to
# `mapping_like_prepare_pydantic_annotations`. Both the `typing` aliases and the runtime
# `collections` classes are listed because the lookup is a plain membership test on the origin.
MAPPING_TYPES: list[type] = [
typing.Mapping,
typing.MutableMapping,
collections.abc.Mapping,
collections.abc.MutableMapping,
collections.OrderedDict,
typing_extensions.OrderedDict,
typing.DefaultDict,
collections.defaultdict,
collections.Counter,
typing.Counter,
]
# Deque types: `_get_prepare_pydantic_annotations_for_known_type` hands these to
# `sequence_like_prepare_pydantic_annotations`.
DEQUE_TYPES: list[type] = [collections.deque, typing.Deque]


def check_validator_fields_against_field_name(
Expand Down Expand Up @@ -400,16 +425,16 @@ def str_schema(self) -> CoreSchema:

# the following methods can be overridden but should be considered
# unstable / private APIs
def _list_schema(self, tp: Any, items_type: Any) -> CoreSchema:
def _list_schema(self, items_type: Any) -> CoreSchema:
return core_schema.list_schema(self.generate_schema(items_type))

def _dict_schema(self, tp: Any, keys_type: Any, values_type: Any) -> CoreSchema:
def _dict_schema(self, keys_type: Any, values_type: Any) -> CoreSchema:
return core_schema.dict_schema(self.generate_schema(keys_type), self.generate_schema(values_type))

def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema:
def _set_schema(self, items_type: Any) -> CoreSchema:
return core_schema.set_schema(self.generate_schema(items_type))

def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema:
def _frozenset_schema(self, items_type: Any) -> CoreSchema:
return core_schema.frozenset_schema(self.generate_schema(items_type))

def _enum_schema(self, enum_type: type[Enum]) -> CoreSchema:
Expand Down Expand Up @@ -943,13 +968,15 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901
elif obj in TUPLE_TYPES:
return self._tuple_schema(obj)
elif obj in LIST_TYPES:
return self._list_schema(obj, Any)
return self._list_schema(Any)
elif obj in SET_TYPES:
return self._set_schema(obj, Any)
return self._set_schema(Any)
elif obj in FROZEN_SET_TYPES:
return self._frozenset_schema(obj, Any)
return self._frozenset_schema(Any)
elif obj in SEQUENCE_TYPES:
return self._sequence_schema(Any)
elif obj in DICT_TYPES:
return self._dict_schema(obj, Any, Any)
return self._dict_schema(Any, Any)
elif isinstance(obj, TypeAliasType):
return self._type_alias_type_schema(obj)
elif obj is type:
Expand Down Expand Up @@ -986,15 +1013,16 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901

if _typing_extra.is_dataclass(obj):
return self._dataclass_schema(obj, None)
res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
return self._apply_annotations(source_type, annotations)

origin = get_origin(obj)
if origin is not None:
return self._match_generic_type(obj, origin)

res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
return self._apply_annotations(source_type, annotations)

if self._arbitrary_types:
return self._arbitrary_type_schema(obj)
return self._unknown_type_schema(obj)
Expand All @@ -1021,24 +1049,29 @@ def _match_generic_type(self, obj: Any, origin: Any) -> CoreSchema: # noqa: C90
elif origin in TUPLE_TYPES:
return self._tuple_schema(obj)
elif origin in LIST_TYPES:
return self._list_schema(obj, self._get_first_arg_or_any(obj))
return self._list_schema(self._get_first_arg_or_any(obj))
elif origin in SET_TYPES:
return self._set_schema(obj, self._get_first_arg_or_any(obj))
return self._set_schema(self._get_first_arg_or_any(obj))
elif origin in FROZEN_SET_TYPES:
return self._frozenset_schema(obj, self._get_first_arg_or_any(obj))
return self._frozenset_schema(self._get_first_arg_or_any(obj))
elif origin in DICT_TYPES:
return self._dict_schema(obj, *self._get_first_two_args_or_any(obj))
return self._dict_schema(*self._get_first_two_args_or_any(obj))
elif is_typeddict(origin):
return self._typed_dict_schema(obj, origin)
elif origin in (typing.Type, type):
return self._subclass_schema(obj)
elif origin in {typing.Sequence, collections.abc.Sequence}:
return self._sequence_schema(obj)
elif origin in SEQUENCE_TYPES:
return self._sequence_schema(self._get_first_arg_or_any(obj))
elif origin in {typing.Iterable, collections.abc.Iterable, typing.Generator, collections.abc.Generator}:
return self._iterable_schema(obj)
elif origin in (re.Pattern, typing.Pattern):
return self._pattern_schema(obj)

res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
return self._apply_annotations(source_type, annotations)

if self._arbitrary_types:
return self._arbitrary_type_schema(origin)
return self._unknown_type_schema(obj)
Expand Down Expand Up @@ -1650,17 +1683,16 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:
else:
return core_schema.is_subclass_schema(type_param)

def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema:
def _sequence_schema(self, items_type: Any) -> core_schema.CoreSchema:
"""Generate schema for a Sequence, e.g. `Sequence[int]`."""
from ._std_types_schema import serialize_sequence_via_list

item_type = self._get_first_arg_or_any(sequence_type)
item_type_schema = self.generate_schema(item_type)
item_type_schema = self.generate_schema(items_type)
list_schema = core_schema.list_schema(item_type_schema)

json_schema = smart_deepcopy(list_schema)
python_schema = core_schema.is_instance_schema(typing.Sequence, cls_repr='Sequence')
if item_type != Any:
if items_type != Any:
from ._validators import sequence_validator

python_schema = core_schema.chain_schema(
Expand Down Expand Up @@ -1989,7 +2021,11 @@ def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema:
def _get_prepare_pydantic_annotations_for_known_type(
self, obj: Any, annotations: tuple[Any, ...]
) -> tuple[Any, list[Any]] | None:
from ._std_types_schema import PREPARE_METHODS
from ._std_types_schema import (
mapping_like_prepare_pydantic_annotations,
path_schema_prepare_pydantic_annotations,
sequence_like_prepare_pydantic_annotations,
)

# Check for hashability
try:
Expand All @@ -1998,12 +2034,19 @@ def _get_prepare_pydantic_annotations_for_known_type(
# obj is definitely not a known type if this fails
return None

for gen in PREPARE_METHODS:
res = gen(obj, annotations, self._config_wrapper.config_dict)
if res is not None:
return res

return None
# TODO: Ideally the generic nature of these types would not be handled in the annotations prep, but in the same
# way as other generic types like list[str], i.e. via _match_generic_type. It's unclear whether that is possible,
# because this method is not always called from match_type; it is sometimes called from _apply_annotations.
obj_origin = get_origin(obj) or obj

if obj_origin in PATH_TYPES:
return path_schema_prepare_pydantic_annotations(obj, annotations)
elif obj_origin in DEQUE_TYPES:
return sequence_like_prepare_pydantic_annotations(obj, annotations)
elif obj_origin in MAPPING_TYPES:
return mapping_like_prepare_pydantic_annotations(obj, annotations)
else:
return None

def _apply_annotations(
self,
Expand Down
9 changes: 4 additions & 5 deletions pydantic/_internal/_std_types_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pydantic.fields import FieldInfo
from pydantic.types import Strict

from ..config import ConfigDict
from ..json_schema import JsonSchemaValue
from . import _known_annotated_metadata, _typing_extra
from ._internal_dataclass import slots_true
Expand Down Expand Up @@ -60,7 +59,7 @@ def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchem


def path_schema_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
source_type: Any, annotations: Iterable[Any]
) -> tuple[Any, list[Any]] | None:
import pathlib

Expand Down Expand Up @@ -272,7 +271,7 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH


def sequence_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
source_type: Any, annotations: Iterable[Any]
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)

Expand Down Expand Up @@ -443,7 +442,7 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH


def mapping_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
source_type: Any, annotations: Iterable[Any]
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)

Expand Down Expand Up @@ -477,7 +476,7 @@ def mapping_like_prepare_pydantic_annotations(
)


PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any]], tuple[Any, list[Any]] | None], ...] = (
sequence_like_prepare_pydantic_annotations,
mapping_like_prepare_pydantic_annotations,
path_schema_prepare_pydantic_annotations,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_fastapi.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
set -x
set -e

# pin setuptools below v72.0.0 while waiting on a fix for a bug introduced in that release, see https://github.com/pypa/setuptools/issues/4519
echo "PIP_CONSTRAINT=setuptools<72.0.0" >> $GITHUB_ENV

cd fastapi
git fetch --tags

Expand Down
1 change: 1 addition & 0 deletions tests/test_forward_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ class Account(BaseModel):
}


@pytest.mark.xfail(reason='needs more strict annotation checks, see https://github.com/pydantic/pydantic/issues/9988')
def test_forward_ref_with_field(create_module):
@create_module
def module():
Expand Down
9 changes: 2 additions & 7 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,7 @@ class MyModel(BaseModel):
assert json_schema['properties']['e']['enum'] == []


@pytest.mark.xfail(reason='needs more strict annotation checks, see https://github.com/pydantic/pydantic/issues/9988')
@pytest.mark.parametrize(
'kwargs,type_',
[
Expand All @@ -1806,13 +1807,7 @@ class Foo(BaseModel):
a: type_ = Field('foo', title='A title', description='A description', **kwargs)


@pytest.mark.xfail(
reason=(
"We allow python validation functions to wrap the type if the constraint isn't valid for the type. "
"An error isn't raised for an int or float annotated with this same invalid constraint, "
"so it's ok to mark this as xfail for now, but we should improve all of them in the future."
)
)
@pytest.mark.xfail(reason='needs more strict annotation checks, see https://github.com/pydantic/pydantic/issues/9988')
def test_invalid_decimal_constraint():
with pytest.raises(
TypeError, match="The following constraints cannot be applied to <class 'decimal.Decimal'>: 'max_length'"
Expand Down
1 change: 1 addition & 0 deletions tests/test_types_self.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class SubSelfRef(SelfRef):
}


@pytest.mark.xfail(reason='needs more strict annotation checks, see https://github.com/pydantic/pydantic/issues/9988')
def test_self_type_with_field(Self):
with pytest.raises(TypeError, match=r'The following constraints cannot be applied.*\'gt\''):

Expand Down

0 comments on commit ee3e3b1

Please sign in to comment.