Skip to content

Commit

Permalink
feat(models): add to_dict & to_json helper methods (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Apr 9, 2024
1 parent 35aa8a1 commit 2f5193d
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 9 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def main() -> None:
print()

message = await stream.get_final_message()
print(message.model_dump_json(indent=2))
print(message.to_json())

asyncio.run(main())
```
Expand Down Expand Up @@ -231,10 +231,10 @@ For a more complete example see [`examples/vertex.py`](https://github.com/anthro

## Using types

Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev), which provide helper methods for things like:
Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like:

- Serializing back into JSON, `model.model_dump_json(indent=2, exclude_unset=True)`
- Converting to a dictionary, `model.model_dump(exclude_unset=True)`
- Serializing back into JSON, `model.to_json()`
- Converting to a dictionary, `model.to_dict()`

Typed requests and responses provide autocomplete and documentation within your editor. If you would like to see type errors in VS Code to help catch bugs earlier, set `python.analysis.typeCheckingMode` to `basic`.

Expand Down
2 changes: 1 addition & 1 deletion examples/messages_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def main() -> None:
# the context manager, as long as the entire stream was consumed
# inside of the context manager
accumulated = await stream.get_final_message()
print("accumulated message: ", accumulated.model_dump_json(indent=2))
print("accumulated message: ", accumulated.to_json())


asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/messages_stream_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def main() -> None:
event_handler=MyStream,
) as stream:
accumulated = await stream.get_final_message()
print("accumulated message: ", accumulated.model_dump_json(indent=2))
print("accumulated message: ", accumulated.to_json())


asyncio.run(main())
4 changes: 2 additions & 2 deletions examples/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def sync_client() -> None:
}
],
)
print(message.model_dump_json(indent=2))
print(message.to_json())


async def async_client() -> None:
Expand All @@ -36,7 +36,7 @@ async def async_client() -> None:
}
],
)
print(message.model_dump_json(indent=2))
print(message.to_json())


sync_client()
Expand Down
2 changes: 1 addition & 1 deletion helpers.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def main() -> None:
event_handler=MyStream,
) as stream:
message = await stream.get_final_message()
print("accumulated message: ", message.model_dump_json(indent=2))
print("accumulated message: ", message.to_json())

asyncio.run(main())
```
Expand Down
73 changes: 73 additions & 0 deletions src/anthropic/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,79 @@ def model_fields_set(self) -> set[str]:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore

def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
"""
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)

def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
"""
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)

@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
Expand Down
64 changes: 64 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,42 @@ class Model(BaseModel):
assert "resource_id" in m.model_fields_set


def test_to_dict() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)

m = Model(FOO="hello")
assert m.to_dict() == {"FOO": "hello"}
assert m.to_dict(use_api_names=False) == {"foo": "hello"}

m2 = Model()
assert m2.to_dict() == {}
assert m2.to_dict(exclude_unset=False) == {"FOO": None}
assert m2.to_dict(exclude_unset=False, exclude_none=True) == {}
assert m2.to_dict(exclude_unset=False, exclude_defaults=True) == {}

m3 = Model(FOO=None)
assert m3.to_dict() == {"FOO": None}
assert m3.to_dict(exclude_none=True) == {}
assert m3.to_dict(exclude_defaults=True) == {}

if PYDANTIC_V2:

class Model2(BaseModel):
created_at: datetime

time_str = "2024-03-21T11:39:01.275859"
m4 = Model2.construct(created_at=time_str)
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
assert m4.to_dict(mode="json") == {"created_at": time_str}
else:
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
m.to_dict(mode="json")

with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_dict(warnings=False)


def test_forwards_compat_model_dump_method() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)
Expand Down Expand Up @@ -532,6 +568,34 @@ class Model(BaseModel):
m.model_dump(warnings=False)


def test_to_json() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)

m = Model(FOO="hello")
assert json.loads(m.to_json()) == {"FOO": "hello"}
assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"}

if PYDANTIC_V2:
assert m.to_json(indent=None) == '{"FOO":"hello"}'
else:
assert m.to_json(indent=None) == '{"FOO": "hello"}'

m2 = Model()
assert json.loads(m2.to_json()) == {}
assert json.loads(m2.to_json(exclude_unset=False)) == {"FOO": None}
assert json.loads(m2.to_json(exclude_unset=False, exclude_none=True)) == {}
assert json.loads(m2.to_json(exclude_unset=False, exclude_defaults=True)) == {}

m3 = Model(FOO=None)
assert json.loads(m3.to_json()) == {"FOO": None}
assert json.loads(m3.to_json(exclude_none=True)) == {}

if not PYDANTIC_V2:
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_json(warnings=False)


def test_forwards_compat_model_dump_json_method() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)
Expand Down

0 comments on commit 2f5193d

Please sign in to comment.