diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 326e772909..75fbc87592 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -92,7 +92,7 @@ from ._mock_val_ser import MockCoreSchema from ._schema_generation_shared import CallbackGetCoreSchemaHandler from ._typing_extra import is_finalvar, is_self_type, is_zoneinfo_type -from ._utils import lenient_issubclass +from ._utils import lenient_issubclass, smart_deepcopy if TYPE_CHECKING: from ..fields import ComputedFieldInfo, FieldInfo @@ -1658,6 +1658,7 @@ def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema: item_type_schema = self.generate_schema(item_type) list_schema = core_schema.list_schema(item_type_schema) + json_schema = smart_deepcopy(list_schema) python_schema = core_schema.is_instance_schema(typing.Sequence, cls_repr='Sequence') if item_type != Any: from ._validators import sequence_validator @@ -1670,7 +1671,7 @@ def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema: serialize_sequence_via_list, schema=item_type_schema, info_arg=True ) return core_schema.json_or_python_schema( - json_schema=list_schema, python_schema=python_schema, serialization=serialization + json_schema=json_schema, python_schema=python_schema, serialization=serialization ) def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: diff --git a/tests/test_discriminated_union.py b/tests/test_discriminated_union.py index 804dcbb50c..68b3801713 100644 --- a/tests/test_discriminated_union.py +++ b/tests/test_discriminated_union.py @@ -14,6 +14,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from pydantic.errors import PydanticUserError from pydantic.fields import FieldInfo +from pydantic.functional_validators import model_validator from pydantic.json_schema import GenerateJsonSchema from pydantic.types import Tag @@ -1379,6 +1380,90 @@ class Model(BaseModel): } +def test_sequence_discriminated_union_validation(): + """ + Related issue: https://github.com/pydantic/pydantic/issues/9872 + """ + + class A(BaseModel): + type: Literal['a'] + a_field: str + + class B(BaseModel): + type: Literal['b'] + b_field: str + + class Model(BaseModel): + items: Sequence[Annotated[Union[A, B], Field(discriminator='type')]] + + import json + + data_json = '{"items": [{"type": "b"}]}' + data_dict = json.loads(data_json) + + expected_error = { + 'type': 'missing', + 'loc': ('items', 0, 'b', 'b_field'), + 'msg': 'Field required', + 'input': {'type': 'b'}, + } + + # missing field should be `b_field` only, not including `a_field` + # also `literal_error` should not be reported on `type` + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate(data_dict) + assert exc_info.value.errors(include_url=False) == [expected_error] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate_json(data_json) + assert exc_info.value.errors(include_url=False) == [expected_error] + + +def test_sequence_discriminated_union_validation_with_validator(): + """ + This is the same as the previous test, but add validators to both class. + """ + + class A(BaseModel): + type: Literal['a'] + a_field: str + + @model_validator(mode='after') + def check_a(self): + return self + + class B(BaseModel): + type: Literal['b'] + b_field: str + + @model_validator(mode='after') + def check_b(self): + return self + + class Model(BaseModel): + items: Sequence[Annotated[Union[A, B], Field(discriminator='type')]] + + import json + + data_json = '{"items": [{"type": "b"}]}' + data_dict = json.loads(data_json) + + expected_error = { + 'type': 'missing', + 'loc': ('items', 0, 'b', 'b_field'), + 'msg': 'Field required', + 'input': {'type': 'b'}, + } + + # missing field should be `b_field` only, not including `a_field` + # also `literal_error` should not be reported on `type` + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate(data_dict) + assert exc_info.value.errors(include_url=False) == [expected_error] + + @pytest.fixture(scope='session', name='animals') def callable_discriminated_union_animals() -> SimpleNamespace: class Cat(BaseModel):