From 7b577c5faaaa6e6a7cf59bd4a0835411f6b8635f Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:28:04 +0100 Subject: [PATCH] chore(internal): improve deserialisation of discriminated unions (#386) --- src/anthropic/_models.py | 160 ++++++++++++++++++++++++- src/anthropic/_utils/_transform.py | 5 +- tests/test_models.py | 180 +++++++++++++++++++++++++++++ 3 files changed, 343 insertions(+), 2 deletions(-) diff --git a/src/anthropic/_models.py b/src/anthropic/_models.py index af68e618..88afa408 100644 --- a/src/anthropic/_models.py +++ b/src/anthropic/_models.py @@ -10,6 +10,7 @@ Protocol, Required, TypedDict, + TypeGuard, final, override, runtime_checkable, @@ -31,6 +32,7 @@ HttpxRequestFiles, ) from ._utils import ( + PropertyInfo, is_list, is_given, is_mapping, @@ -39,6 +41,7 @@ strip_not_given, extract_type_arg, is_annotated_type, + strip_annotated_type, ) from ._compat import ( PYDANTIC_V2, @@ -55,6 +58,9 @@ ) from ._constants import RAW_RESPONSE_HEADER +if TYPE_CHECKING: + from pydantic_core.core_schema import ModelField, ModelFieldsSchema + __all__ = ["BaseModel", "GenericModel"] _T = TypeVar("_T") @@ -268,7 +274,6 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: def is_basemodel(type_: type) -> bool: """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" - origin = get_origin(type_) or type_ if is_union(type_): for variant in get_args(type_): if is_basemodel(variant): @@ -276,6 +281,11 @@ def is_basemodel(type_: type) -> bool: return False + return is_basemodel_type(type_) + + +def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: + origin = get_origin(type_) or type_ return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) @@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object: """ # unwrap `Annotated[T, ...]` -> `T` if is_annotated_type(type_): + meta = get_args(type_)[1:] type_ = extract_type_arg(type_, 0) + else: + meta = tuple() # we need to use the origin class for any types that are subscripted generics # e.g. Dict[str, object] @@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object: except Exception: pass + # if the type is a discriminated union then we want to construct the right variant + # in the union, even if the data doesn't match exactly, otherwise we'd break code + # that relies on the constructed class types, e.g. + # + # class FooType: + # kind: Literal['foo'] + # value: str + # + # class BarType: + # kind: Literal['bar'] + # value: int + # + # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then + # we'd end up constructing `FooType` when it should be `BarType`. + discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + if discriminator and is_mapping(value): + variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + if variant_value and isinstance(variant_value, str): + variant_type = discriminator.mapping.get(variant_value) + if variant_type: + return construct_type(type_=variant_type, value=value) + # if the data is not valid, use the first variant that doesn't fail while deserializing for variant in args: try: @@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object: return value +@runtime_checkable +class CachedDiscriminatorType(Protocol): + __discriminator__: DiscriminatorDetails + + +class DiscriminatorDetails: + field_name: str + """The name of the discriminator field in the variant class, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] + ``` + + Will result in field_name='type' + """ + + field_alias_from: str | None + """The name of the discriminator field in the API response, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] = Field(alias='type_from_api') + ``` + + Will result in field_alias_from='type_from_api' + """ + + mapping: dict[str, type] + """Mapping of discriminator value to variant type, e.g. + + {'foo': FooVariant, 'bar': BarVariant} + """ + + def __init__( + self, + *, + mapping: dict[str, type], + discriminator_field: str, + discriminator_alias: str | None, + ) -> None: + self.mapping = mapping + self.field_name = discriminator_field + self.field_alias_from = discriminator_alias + + +def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: + if isinstance(union, CachedDiscriminatorType): + return union.__discriminator__ + + discriminator_field_name: str | None = None + + for annotation in meta_annotations: + if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + discriminator_field_name = annotation.discriminator + break + + if not discriminator_field_name: + return None + + mapping: dict[str, type] = {} + discriminator_alias: str | None = None + + for variant in get_args(union): + variant = strip_annotated_type(variant) + if is_basemodel_type(variant): + if PYDANTIC_V2: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in field_schema["expected"]: + if isinstance(entry, str): + mapping[entry] = variant + else: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field_info.alias + + if field_info.annotation and is_literal_type(field_info.annotation): + for entry in get_args(field_info.annotation): + if isinstance(entry, str): + mapping[entry] = variant + + if not mapping: + return None + + details = DiscriminatorDetails( + mapping=mapping, + discriminator_field=discriminator_field_name, + discriminator_alias=discriminator_alias, + ) + cast(CachedDiscriminatorType, union).__discriminator__ = details + return details + + +def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: + schema = model.__pydantic_core_schema__ + if schema["type"] != "model": + return None + + fields_schema = schema["schema"] + if fields_schema["type"] != "model-fields": + return None + + fields_schema = cast("ModelFieldsSchema", fields_schema) + + field = fields_schema["fields"].get(field_name) + if not field: + return None + + return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] + + def validate_type(*, type_: type[_T], value: object) -> _T: """Strict validation that the given value matches the expected type""" if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): diff --git a/src/anthropic/_utils/_transform.py b/src/anthropic/_utils/_transform.py index 1bd1330c..47e262a5 100644 --- a/src/anthropic/_utils/_transform.py +++ b/src/anthropic/_utils/_transform.py @@ -51,6 +51,7 @@ class MyParams(TypedDict): alias: str | None format: PropertyFormat | None format_template: str | None + discriminator: str | None def __init__( self, @@ -58,14 +59,16 @@ def __init__( alias: str | None = None, format: PropertyFormat | None = None, format_template: str | None = None, + discriminator: str | None = None, ) -> None: self.alias = alias self.format = format self.format_template = format_template + self.discriminator = discriminator @override def __repr__(self) -> str: - return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')" + return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" def maybe_transform( diff --git a/tests/test_models.py b/tests/test_models.py index ff16455c..26d4696b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,6 +7,7 @@ import pydantic from pydantic import Field +from anthropic._utils import PropertyInfo from anthropic._compat import PYDANTIC_V2, parse_obj, model_dump, model_json from anthropic._models import BaseModel, construct_type @@ -583,3 +584,182 @@ class Model(BaseModel): ) assert isinstance(m, Model) assert m.value == "foo" + + +def test_discriminated_unions_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, A) + assert m.type == "a" + if PYDANTIC_V2: + assert m.data == 100 # type: ignore[comparison-overlap] + else: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + + +def test_discriminated_unions_unknown_variant() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "c", "data": None, "new_thing": "bar"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + + # just chooses the first variant + assert isinstance(m, A) + assert m.type == "c" # type: ignore[comparison-overlap] + assert m.data == None # type: ignore[unreachable] + assert m.new_thing == "bar" + + +def test_discriminated_unions_invalid_data_nested_unions() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + class C(BaseModel): + type: Literal["c"] + + data: bool + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "c", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, C) + assert m.type == "c" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_with_aliases_invalid_data() -> None: + class A(BaseModel): + foo_type: Literal["a"] = Field(alias="type") + + data: str + + class B(BaseModel): + foo_type: Literal["b"] = Field(alias="type") + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, B) + assert m.foo_type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, A) + assert m.foo_type == "a" + if PYDANTIC_V2: + assert m.data == 100 # type: ignore[comparison-overlap] + else: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + + +def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["a"] + + data: int + + m = construct_type( + value={"type": "a", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "a" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_invalid_data_uses_cache() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + UnionType = cast(Any, Union[A, B]) + + assert not hasattr(UnionType, "__discriminator__") + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + discriminator = UnionType.__discriminator__ + assert discriminator is not None + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + # if the discriminator details object stays the same between invocations then + # we hit the cache + assert UnionType.__discriminator__ is discriminator