Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for form encoded bodies #18

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions django_api_decorator/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic_core import PydanticUndefined

from .types import ApiMeta, FieldError, PublicAPIError
from .utils import get_list_fields, parse_form_encoded_body

P = typing.ParamSpec("P")
T = typing.TypeVar("T")
Expand Down Expand Up @@ -101,9 +102,9 @@ def decorator(func: Callable[..., Any]) -> Callable[..., HttpResponse]:

# If the method has a "body" argument, get a function to call to parse
# the request body into the type expected by the view.
body_adapter = None
if "body" in signature.parameters:
body_adapter = _get_body_adapter(parameter=signature.parameters["body"])
list_fields, body_adapter = set[str](), None
if body_annotation := signature.parameters.get("body"):
list_fields, body_adapter = _get_body_adapter(body_annotation)

# Get a function to use for encoding the value returned from the view
# into a request we can return to the client.
Expand All @@ -126,7 +127,14 @@ def inner(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# Parse the request body if the request method allows a body and the
# view has requested that we should parse the body.
if _can_have_body(request.method) and body_adapter:
extra_kwargs["body"] = body_adapter.validate_json(request.body)
if request.content_type in {
"application/x-www-form-urlencoded",
"multipart/form-data",
}:
data = parse_form_encoded_body(request, list_fields)
extra_kwargs["body"] = body_adapter.validate_python(data)
else:
extra_kwargs["body"] = body_adapter.validate_json(request.body)

# Parse query params and add them to the parameters given to the view.
raw_query_params: dict[str, Any] = {}
Expand Down Expand Up @@ -261,12 +269,18 @@ def _can_have_body(method: str | None) -> bool:
return method in ("POST", "PATCH", "PUT")


def _get_body_adapter(*, parameter: inspect.Parameter) -> pydantic.TypeAdapter[Any]:
def _get_body_adapter(
parameter: inspect.Parameter,
) -> tuple[set[str], pydantic.TypeAdapter[Any]]:
annotation = parameter.annotation
if annotation is inspect.Parameter.empty:
raise TypeError("The body parameter must have a type annotation")

return pydantic.TypeAdapter(annotation)
list_fields = set()
if isinstance(annotation, type) and issubclass(annotation, pydantic.BaseModel):
list_fields = get_list_fields(annotation)

return list_fields, pydantic.TypeAdapter(annotation)


#####################
Expand Down
34 changes: 33 additions & 1 deletion django_api_decorator/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Callable, ParamSpec, cast
from types import UnionType
from typing import Any, Callable, List, ParamSpec, Set, Tuple, Union, cast, get_origin

from django.http import HttpRequest, HttpResponse, HttpResponseNotAllowed
from pydantic import BaseModel

P = ParamSpec("P")

Expand Down Expand Up @@ -58,3 +60,33 @@ def call_view(
call_view._method_router_views = views # type: ignore[attr-defined]

return cast(Callable[..., HttpResponse], call_view)


def is_list_type(annotation: Any) -> bool:
"""
Check if the given annotation is a collection type.
"""

origin = get_origin(annotation)
return origin in (Union, UnionType, List, Tuple, Set, list, set, tuple)


def get_list_fields(model: type[BaseModel]) -> set[str]:
return {
name
for name, field in model.model_fields.items()
if is_list_type(field.annotation)
}


def parse_form_encoded_body(
request: HttpRequest, list_fields: set[str]
) -> dict[str, str | list[str] | None]:
"""
Convert request.POST to a format pydantic can take as input
"""

return {
key: request.POST.getlist(key) if key in list_fields else request.POST.get(key)
for key in request.POST
}
142 changes: 116 additions & 26 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def post_view(request: HttpRequest) -> JsonResponse:
path("get", get_view),
path("post", post_view),
]

mocker.patch(f"{__name__}.urlpatterns", urls)

response = client.get("/get")
Expand Down Expand Up @@ -69,7 +68,6 @@ def noauth_anonymous_view(request: HttpRequest) -> JsonResponse:
path("noauth-user", noauth_user_view),
path("noauth-anonymous", noauth_anonymous_view),
]

mocker.patch(f"{__name__}.urlpatterns", urls)

response = client.get("/auth-user")
Expand Down Expand Up @@ -173,12 +171,10 @@ def test_query_params( # type: ignore
mocker, client, settings, view, have_url, want_status, want_values
):
collector, api_view = _create_api_view(view, ["query_param"]) # type: ignore
urls = [
path("", api_view),
]

urls = [path("", api_view)]
settings.ROOT_URLCONF = __name__
mocker.patch(f"{__name__}.urlpatterns", urls)

got = client.get(have_url)
assert got.status_code == want_status
if want_values is None:
Expand Down Expand Up @@ -213,9 +209,7 @@ def test_path_params( # type: ignore
"""

collector, api_view = _create_api_view(view, None) # type: ignore
urls = [
path(path_spec, api_view),
]
urls = [path(path_spec, api_view)]
settings.ROOT_URLCONF = __name__
mocker.patch(f"{__name__}.urlpatterns", urls)

Expand All @@ -240,17 +234,13 @@ class Body(BaseModel):
def view(request: HttpRequest, body: Body) -> JsonResponse:
return JsonResponse({})

urls = [
path("", view),
]

urls = [path("", view)]
mocker.patch(f"{__name__}.urlpatterns", urls)

# No data is invalid
assert client.post("/").status_code == 400
# No content_type is invalid
assert client.post("/", data={}).status_code == 400
# Json content type is ok
# Allow empty body with empty type
assert client.post("/").status_code == 200
assert client.post("/", data={}).status_code == 200
# Allow empty dict JSON as well
assert client.post("/", data={}, content_type="application/json").status_code == 200


Expand All @@ -267,10 +257,7 @@ class Body(BaseModel):
def view(request: HttpRequest, body: Body) -> JsonResponse:
return JsonResponse({})

urls = [
path("", view),
]

urls = [path("", view)]
mocker.patch(f"{__name__}.urlpatterns", urls)

assert client.post("/", data={}, content_type="application/json").status_code == 400
Expand All @@ -294,6 +281,112 @@ def view(request: HttpRequest, body: Body) -> JsonResponse:
assert response.json()["field_errors"].keys() == {"num", "d"}


@override_settings(ROOT_URLCONF=__name__)
def test_parsing_form_encoded(client: Client, mocker: MockerFixture) -> None:
class Body(BaseModel):
num: int
d: datetime.date

@api(
method="POST",
login_required=False,
)
def view(request: HttpRequest, body: Body) -> JsonResponse:
return JsonResponse(body.model_dump(mode="json"))

urls = [path("", view)]
mocker.patch(f"{__name__}.urlpatterns", urls)

# Test missing fields
response = client.post("/", data={})
assert response.status_code == 400
assert response.json() == {
"errors": ["num: Field required", "d: Field required"],
"field_errors": {
"num": {"code": "missing", "message": "Field required"},
"d": {"code": "missing", "message": "Field required"},
},
}

response = client.post("/", data={"num": 3, "d": "2022-01-01"})
assert response.status_code == 200
assert response.json() == {"num": 3, "d": "2022-01-01"}

# Check that field errors propagate
response = client.post("/", data={"num": "x", "d": "2022-01-01"})
assert response.status_code == 400
assert response.json() == {
"errors": [
"num: Input should be a valid integer, "
"unable to parse string as an integer"
],
"field_errors": {
"num": {
"code": "int_parsing",
"message": "Input should be a valid integer, "
"unable to parse string as an integer",
}
},
}

response = client.post("/", data={"num": 1, "d": "2022-31-41"})
assert response.status_code == 400
assert response.json() == {
"errors": [
"d: Input should be a valid date or datetime, "
"month value is outside expected range of 1-12"
],
"field_errors": {
"d": {
"code": "date_from_datetime_parsing",
"message": (
"Input should be a valid date or datetime, "
"month value is outside expected range of 1-12"
),
}
},
}


@override_settings(ROOT_URLCONF=__name__)
def test_parsing_form_encoded_list(client: Client, mocker: MockerFixture) -> None:
class Body(BaseModel):
numbers: list[int]

@api(
method="POST",
login_required=False,
)
def view(request: HttpRequest, body: Body) -> JsonResponse:
return JsonResponse(body.model_dump(mode="json"))

urls = [path("", view)]
mocker.patch(f"{__name__}.urlpatterns", urls)

response = client.post("/", data={"numbers": [3, 4]})
assert response.status_code == 200
assert response.json() == {"numbers": [3, 4]}

# Check that field errors propagate
response = client.post("/", data={"numbers": "hello"})
assert response.status_code == 400
assert response.json() == {
"errors": [
"numbers.0: Input should be a valid integer, "
"unable to parse string as an integer"
],
"field_errors": {
"numbers.0": {
"code": "int_parsing",
"message": (
"Input should be a valid integer, "
"unable to parse string as an integer"
),
}
},
}


@override_settings(ROOT_URLCONF=__name__)
def test_parsing_list(client: Client, mocker: MockerFixture) -> None:
class Body(BaseModel):
Expand All @@ -307,10 +400,7 @@ class Body(BaseModel):
def view(request: HttpRequest, body: list[Body]) -> JsonResponse:
return JsonResponse({})

urls = [
path("", view),
]

urls = [path("", view)]
mocker.patch(f"{__name__}.urlpatterns", urls)

assert client.post("/", data={}, content_type="application/json").status_code == 400
Expand Down
Loading