Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add msgspec support for openapi.Component and @validate() #229

Merged
merged 6 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions sanic_ext/extensions/openapi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@

from sanic.exceptions import SanicException

from sanic_ext.utils.typing import contains_annotations, is_pydantic
from sanic_ext.utils.typing import (
contains_annotations,
is_msgspec,
is_pydantic,
)

from .types import Definition, Schema

Expand Down Expand Up @@ -409,7 +413,16 @@ def Component(
if not spec.has_component(field, name):
prop_info = hints[field]
type_ = prop_info.__args__[1]
if is_pydantic(obj):
if is_msgspec(obj):
import msgspec

_, definitions = msgspec.json.schema_components(
[obj], ref_template="#/components/schemas/{name}"
)
if definitions:
for key, value in definitions.items():
spec.add_component(field, key, value)
elif is_pydantic(obj):
try:
schema = obj.schema
except AttributeError:
Expand All @@ -419,12 +432,12 @@ def Component(
if definitions:
for key, value in definitions.items():
spec.add_component(field, key, value)
spec.add_component(field, name, component)
else:
component = (
type_.make(obj) if hasattr(type_, "make") else type_(obj)
)

spec.add_component(field, name, component)
spec.add_component(field, name, component)

return ref

Expand Down
9 changes: 6 additions & 3 deletions sanic_ext/extras/validation/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from sanic.log import logger

from sanic_ext.exceptions import ValidationError
from sanic_ext.utils.typing import is_pydantic
from sanic_ext.utils.typing import is_msgspec, is_pydantic

from .schema import make_schema
from .validators import (
_msgspec_validate_instance,
_validate_annotations,
_validate_instance,
validate_body,
Expand Down Expand Up @@ -46,7 +47,7 @@ async def do_validation(

def generate_schema(param):
try:
if param is None or is_pydantic(param):
if param is None or is_msgspec(param) or is_pydantic(param):
return param
except TypeError:
...
Expand All @@ -55,7 +56,9 @@ def generate_schema(param):


def _get_validator(model, schema, allow_multiple, allow_coerce):
if is_pydantic(model):
if is_msgspec(model):
return partial(_msgspec_validate_instance, allow_coerce=allow_coerce)
elif is_pydantic(model):
return partial(_validate_instance, allow_coerce=allow_coerce)

return partial(
Expand Down
12 changes: 12 additions & 0 deletions sanic_ext/extras/validation/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def validate_body(
)


def _msgspec_validate_instance(model, body, allow_coerce):
import msgspec

try:
data = clean_data(model, body) if allow_coerce else body
return msgspec.convert(data, model)
except msgspec.ValidationError as e:
# Convert msgspec.ValidationError into TypeError for consistent
# behaviour with _validate_instance
raise TypeError(str(e))


def _validate_instance(model, body, allow_coerce):
data = clean_data(model, body) if allow_coerce else body
return model(**data)
Expand Down
12 changes: 12 additions & 0 deletions tests/extensions/openapi/test_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import attrs
import pytest
from msgspec import Struct
from pydantic import BaseModel
from pydantic.dataclasses import dataclass as pydataclass

Expand Down Expand Up @@ -34,6 +35,16 @@ class AlertResponsePydanticBaseModel(BaseModel):
rule_id: str


class AlertMsgspecBaseModel(Struct):
hit: Dict[str, int]
last_updated: datetime


class AlertResponseMsgspecBaseModel(Struct):
alert: AlertMsgspecBaseModel
rule_id: str


@pydataclass
class AlertPydanticDataclass:
hit: Dict[str, int]
Expand Down Expand Up @@ -63,6 +74,7 @@ class AlertResponseAttrs:
(
(AlertResponseDataclass, False),
(AlertResponseAttrs, False),
(AlertResponseMsgspecBaseModel, True),
(AlertResponsePydanticBaseModel, True),
(AlertResponsePydanticDataclass, True),
),
Expand Down
70 changes: 61 additions & 9 deletions tests/extra/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from sys import version_info
from typing import Literal, Union

Expand All @@ -8,6 +7,45 @@
from sanic_ext import validate


def _dataclass_spec(annotation):
from dataclasses import dataclass

@dataclass
class Spec:
name: annotation

return Spec


def _attrs_spec(annotation):
import attrs

@attrs.define
class Spec:
name: annotation

return Spec


def _msgspec_spec(annotation):
from msgspec import Struct

class Spec(Struct):
name: annotation

return Spec


def _pydantic_spec(annotation):
from pydantic.dataclasses import dataclass as pydataclass

@pydataclass
class Spec:
name: annotation

return Spec


@pytest.mark.parametrize(
"annotation",
(
Expand All @@ -18,10 +56,17 @@
)
),
)
def test_literal(app, annotation):
@dataclass
class Spec:
name: annotation
@pytest.mark.parametrize(
"spec_builder",
(
_dataclass_spec,
_attrs_spec,
_msgspec_spec,
_pydantic_spec,
),
)
def test_literal(app, annotation, spec_builder):
Spec = spec_builder(annotation)

@app.get("/")
@validate(query=Spec)
Expand All @@ -33,10 +78,17 @@ def route(_, query: Spec):


@pytest.mark.skipif(version_info < (3, 10), reason="Not needed on 3.10")
def test_literal_3_10(app):
@dataclass
class Spec:
name: Literal["foo"] | Literal["bar"]
@pytest.mark.parametrize(
"spec_builder",
(
_dataclass_spec,
_attrs_spec,
_msgspec_spec,
_pydantic_spec,
),
)
def test_literal_3_10(app, spec_builder):
Spec = spec_builder(Literal["foo"] | Literal["bar"])

@app.get("/")
@validate(query=Spec)
Expand Down
Loading