Skip to content

Commit

Permalink
chore(internal): improve deserialisation of discriminated unions (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Mar 11, 2024
1 parent 9f28a8f commit 7b577c5
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 2 deletions.
160 changes: 159 additions & 1 deletion src/anthropic/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Protocol,
Required,
TypedDict,
TypeGuard,
final,
override,
runtime_checkable,
Expand All @@ -31,6 +32,7 @@
HttpxRequestFiles,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
is_mapping,
Expand All @@ -39,6 +41,7 @@
strip_not_given,
extract_type_arg,
is_annotated_type,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
Expand All @@ -55,6 +58,9 @@
)
from ._constants import RAW_RESPONSE_HEADER

if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, ModelFieldsSchema

__all__ = ["BaseModel", "GenericModel"]

_T = TypeVar("_T")
Expand Down Expand Up @@ -268,14 +274,18 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:

def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
origin = get_origin(type_) or type_
if is_union(type_):
for variant in get_args(type_):
if is_basemodel(variant):
return True

return False

return is_basemodel_type(type_)


def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


Expand All @@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
"""
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
meta = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = tuple()

# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
Expand All @@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
except Exception:
pass

# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)

# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
Expand Down Expand Up @@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
return value


@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails


class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""

field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""

mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""

def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias


def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__

discriminator_field_name: str | None = None

for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break

if not discriminator_field_name:
return None

mapping: dict[str, type] = {}
discriminator_alias: str | None = None

for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V2:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue

# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")

field_schema = field["schema"]

if field_schema["type"] == "literal":
for entry in field_schema["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue

# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias

if field_info.annotation and is_literal_type(field_info.annotation):
for entry in get_args(field_info.annotation):
if isinstance(entry, str):
mapping[entry] = variant

if not mapping:
return None

details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
cast(CachedDiscriminatorType, union).__discriminator__ = details
return details


def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None

fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None

fields_schema = cast("ModelFieldsSchema", fields_schema)

field = fields_schema["fields"].get(field_name)
if not field:
return None

return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]


def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion src/anthropic/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,24 @@ class MyParams(TypedDict):
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None

def __init__(
self,
*,
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
self.discriminator = discriminator

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')"
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"


def maybe_transform(
Expand Down
Loading

0 comments on commit 7b577c5

Please sign in to comment.