diff --git a/sanic_ext/extensions/openapi/definitions.py b/sanic_ext/extensions/openapi/definitions.py index 7603386..e29c66f 100644 --- a/sanic_ext/extensions/openapi/definitions.py +++ b/sanic_ext/extensions/openapi/definitions.py @@ -20,7 +20,11 @@ from sanic.exceptions import SanicException -from sanic_ext.utils.typing import contains_annotations, is_pydantic +from sanic_ext.utils.typing import ( + contains_annotations, + is_msgspec, + is_pydantic, +) from .types import Definition, Schema @@ -409,7 +413,16 @@ def Component( if not spec.has_component(field, name): prop_info = hints[field] type_ = prop_info.__args__[1] - if is_pydantic(obj): + if is_msgspec(obj): + import msgspec + + _, definitions = msgspec.json.schema_components( + [obj], ref_template="#/components/schemas/{name}" + ) + if definitions: + for key, value in definitions.items(): + spec.add_component(field, key, value) + elif is_pydantic(obj): try: schema = obj.schema except AttributeError: @@ -419,12 +432,12 @@ def Component( if definitions: for key, value in definitions.items(): spec.add_component(field, key, value) + spec.add_component(field, name, component) else: component = ( type_.make(obj) if hasattr(type_, "make") else type_(obj) ) - - spec.add_component(field, name, component) + spec.add_component(field, name, component) return ref diff --git a/sanic_ext/extras/validation/setup.py b/sanic_ext/extras/validation/setup.py index 4b83d56..b869273 100644 --- a/sanic_ext/extras/validation/setup.py +++ b/sanic_ext/extras/validation/setup.py @@ -4,10 +4,11 @@ from sanic.log import logger from sanic_ext.exceptions import ValidationError -from sanic_ext.utils.typing import is_pydantic +from sanic_ext.utils.typing import is_msgspec, is_pydantic from .schema import make_schema from .validators import ( + _msgspec_validate_instance, _validate_annotations, _validate_instance, validate_body, @@ -46,7 +47,7 @@ async def do_validation( def generate_schema(param): try: - if param is None or is_pydantic(param): + if param is None or is_msgspec(param) or is_pydantic(param): return param except TypeError: ... @@ -55,7 +56,9 @@ def generate_schema(param): def _get_validator(model, schema, allow_multiple, allow_coerce): - if is_pydantic(model): + if is_msgspec(model): + return partial(_msgspec_validate_instance, allow_coerce=allow_coerce) + elif is_pydantic(model): return partial(_validate_instance, allow_coerce=allow_coerce) return partial( diff --git a/sanic_ext/extras/validation/validators.py b/sanic_ext/extras/validation/validators.py index 4bcfbb0..c3b2276 100644 --- a/sanic_ext/extras/validation/validators.py +++ b/sanic_ext/extras/validation/validators.py @@ -31,6 +31,18 @@ def validate_body( ) +def _msgspec_validate_instance(model, body, allow_coerce): + import msgspec + + try: + data = clean_data(model, body) if allow_coerce else body + return msgspec.convert(data, model) + except msgspec.ValidationError as e: + # Convert msgspec.ValidationError into TypeError for consistent + # behaviour with _validate_instance + raise TypeError(str(e)) + + def _validate_instance(model, body, allow_coerce): data = clean_data(model, body) if allow_coerce else body return model(**data) diff --git a/tests/extensions/openapi/test_model_spec.py b/tests/extensions/openapi/test_model_spec.py index 8563509..b140241 100644 --- a/tests/extensions/openapi/test_model_spec.py +++ b/tests/extensions/openapi/test_model_spec.py @@ -4,6 +4,7 @@ import attrs import pytest +from msgspec import Struct from pydantic import BaseModel from pydantic.dataclasses import dataclass as pydataclass @@ -34,6 +35,16 @@ class AlertResponsePydanticBaseModel(BaseModel): rule_id: str +class AlertMsgspecBaseModel(Struct): + hit: Dict[str, int] + last_updated: datetime + + +class AlertResponseMsgspecBaseModel(Struct): + alert: AlertMsgspecBaseModel + rule_id: str + + @pydataclass class AlertPydanticDataclass: hit: Dict[str, int] @@ -63,6 +74,7 @@ class AlertResponseAttrs: ( (AlertResponseDataclass, False), (AlertResponseAttrs, False), + (AlertResponseMsgspecBaseModel, True), (AlertResponsePydanticBaseModel, True), (AlertResponsePydanticDataclass, True), ), diff --git a/tests/extra/test_validation.py b/tests/extra/test_validation.py index b1c8eb9..8da1bf7 100644 --- a/tests/extra/test_validation.py +++ b/tests/extra/test_validation.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from sys import version_info from typing import Literal, Union @@ -8,6 +7,45 @@ from sanic_ext import validate +def _dataclass_spec(annotation): + from dataclasses import dataclass + + @dataclass + class Spec: + name: annotation + + return Spec + + +def _attrs_spec(annotation): + import attrs + + @attrs.define + class Spec: + name: annotation + + return Spec + + +def _msgspec_spec(annotation): + from msgspec import Struct + + class Spec(Struct): + name: annotation + + return Spec + + +def _pydantic_spec(annotation): + from pydantic.dataclasses import dataclass as pydataclass + + @pydataclass + class Spec: + name: annotation + + return Spec + + @pytest.mark.parametrize( "annotation", ( @@ -18,10 +56,17 @@ ) ), ) -def test_literal(app, annotation): - @dataclass - class Spec: - name: annotation +@pytest.mark.parametrize( + "spec_builder", + ( + _dataclass_spec, + _attrs_spec, + _msgspec_spec, + _pydantic_spec, + ), +) +def test_literal(app, annotation, spec_builder): + Spec = spec_builder(annotation) @app.get("/") @validate(query=Spec) @@ -33,10 +78,17 @@ def route(_, query: Spec): @pytest.mark.skipif(version_info < (3, 10), reason="Not needed on 3.10") -def test_literal_3_10(app): - @dataclass - class Spec: - name: Literal["foo"] | Literal["bar"] +@pytest.mark.parametrize( + "spec_builder", + ( + _dataclass_spec, + _attrs_spec, + _msgspec_spec, + _pydantic_spec, + ), +) +def test_literal_3_10(app, spec_builder): + Spec = spec_builder(Literal["foo"] | Literal["bar"]) @app.get("/") @validate(query=Spec)