diff --git a/django_api_decorator/decorators.py b/django_api_decorator/decorators.py index 544a49e..8f9d483 100644 --- a/django_api_decorator/decorators.py +++ b/django_api_decorator/decorators.py @@ -16,6 +16,7 @@ from pydantic_core import PydanticUndefined from .types import ApiMeta, FieldError, PublicAPIError +from .utils import get_list_fields, parse_form_encoded_body P = typing.ParamSpec("P") T = typing.TypeVar("T") @@ -101,9 +102,9 @@ def decorator(func: Callable[..., Any]) -> Callable[..., HttpResponse]: # If the method has a "body" argument, get a function to call to parse # the request body into the type expected by the view. - body_adapter = None - if "body" in signature.parameters: - body_adapter = _get_body_adapter(parameter=signature.parameters["body"]) + list_fields, body_adapter = set[str](), None + if body_annotation := signature.parameters.get("body"): + list_fields, body_adapter = _get_body_adapter(body_annotation) # Get a function to use for encoding the value returned from the view # into a request we can return to the client. @@ -126,7 +127,14 @@ def inner(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: # Parse the request body if the request method allows a body and the # view has requested that we should parse the body. if _can_have_body(request.method) and body_adapter: - extra_kwargs["body"] = body_adapter.validate_json(request.body) + if request.content_type in { + "application/x-www-form-urlencoded", + "multipart/form-data", + }: + data = parse_form_encoded_body(request, list_fields) + extra_kwargs["body"] = body_adapter.validate_python(data) + else: + extra_kwargs["body"] = body_adapter.validate_json(request.body) # Parse query params and add them to the parameters given to the view. raw_query_params: dict[str, Any] = {} @@ -261,12 +269,18 @@ def _can_have_body(method: str | None) -> bool: return method in ("POST", "PATCH", "PUT") -def _get_body_adapter(*, parameter: inspect.Parameter) -> pydantic.TypeAdapter[Any]: +def _get_body_adapter( + parameter: inspect.Parameter, +) -> tuple[set[str], pydantic.TypeAdapter[Any]]: annotation = parameter.annotation if annotation is inspect.Parameter.empty: raise TypeError("The body parameter must have a type annotation") - return pydantic.TypeAdapter(annotation) + list_fields = set() + if isinstance(annotation, type) and issubclass(annotation, pydantic.BaseModel): + list_fields = get_list_fields(annotation) + + return list_fields, pydantic.TypeAdapter(annotation) ##################### diff --git a/django_api_decorator/utils.py b/django_api_decorator/utils.py index b9eb46b..97e78ed 100644 --- a/django_api_decorator/utils.py +++ b/django_api_decorator/utils.py @@ -1,6 +1,8 @@ -from typing import Callable, ParamSpec, cast +from types import UnionType +from typing import Any, Callable, List, ParamSpec, Set, Tuple, Union, cast, get_origin from django.http import HttpRequest, HttpResponse, HttpResponseNotAllowed +from pydantic import BaseModel P = ParamSpec("P") @@ -58,3 +60,33 @@ def call_view( call_view._method_router_views = views # type: ignore[attr-defined] return cast(Callable[..., HttpResponse], call_view) + + +def is_list_type(annotation: Any) -> bool: + """ + Check if the given annotation is a collection type. + """ + + origin = get_origin(annotation) + return origin in (Union, UnionType, List, Tuple, Set, list, set, tuple) + + +def get_list_fields(model: type[BaseModel]) -> set[str]: + return { + name + for name, field in model.model_fields.items() + if is_list_type(field.annotation) + } + + +def parse_form_encoded_body( + request: HttpRequest, list_fields: set[str] +) -> dict[str, str | list[str] | None]: + """ + Convert request.POST to a format pydantic can take as input + """ + + return { + key: request.POST.getlist(key) if key in list_fields else request.POST.get(key) + for key in request.POST + } diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 62f0405..f32c242 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -29,7 +29,6 @@ def post_view(request: HttpRequest) -> JsonResponse: path("get", get_view), path("post", post_view), ] - mocker.patch(f"{__name__}.urlpatterns", urls) response = client.get("/get") @@ -69,7 +68,6 @@ def noauth_anonymous_view(request: HttpRequest) -> JsonResponse: path("noauth-user", noauth_user_view), path("noauth-anonymous", noauth_anonymous_view), ] - mocker.patch(f"{__name__}.urlpatterns", urls) response = client.get("/auth-user") @@ -173,12 +171,10 @@ def test_query_params( # type: ignore mocker, client, settings, view, have_url, want_status, want_values ): collector, api_view = _create_api_view(view, ["query_param"]) # type: ignore - urls = [ - path("", api_view), - ] - + urls = [path("", api_view)] settings.ROOT_URLCONF = __name__ mocker.patch(f"{__name__}.urlpatterns", urls) + got = client.get(have_url) assert got.status_code == want_status if want_values is None: @@ -213,9 +209,7 @@ def test_path_params( # type: ignore """ collector, api_view = _create_api_view(view, None) # type: ignore - urls = [ - path(path_spec, api_view), - ] + urls = [path(path_spec, api_view)] settings.ROOT_URLCONF = __name__ mocker.patch(f"{__name__}.urlpatterns", urls) @@ -240,17 +234,13 @@ class Body(BaseModel): def view(request: HttpRequest, body: Body) -> JsonResponse: return JsonResponse({}) - urls = [ - path("", view), - ] - + urls = [path("", view)] mocker.patch(f"{__name__}.urlpatterns", urls) - # No data is invalid - assert client.post("/").status_code == 400 - # No content_type is invalid - assert client.post("/", data={}).status_code == 400 - # Json content type is ok + # Allow empty body with empty type + assert client.post("/").status_code == 200 + assert client.post("/", data={}).status_code == 200 + # Allow empty dict JSON as well assert client.post("/", data={}, content_type="application/json").status_code == 200 @@ -267,10 +257,7 @@ class Body(BaseModel): def view(request: HttpRequest, body: Body) -> JsonResponse: return JsonResponse({}) - urls = [ - path("", view), - ] - + urls = [path("", view)] mocker.patch(f"{__name__}.urlpatterns", urls) assert client.post("/", data={}, content_type="application/json").status_code == 400 @@ -294,6 +281,112 @@ def view(request: HttpRequest, body: Body) -> JsonResponse: assert response.json()["field_errors"].keys() == {"num", "d"} +@override_settings(ROOT_URLCONF=__name__) +def test_parsing_form_encoded(client: Client, mocker: MockerFixture) -> None: + class Body(BaseModel): + num: int + d: datetime.date + + @api( + method="POST", + login_required=False, + ) + def view(request: HttpRequest, body: Body) -> JsonResponse: + return JsonResponse(body.model_dump(mode="json")) + + urls = [path("", view)] + mocker.patch(f"{__name__}.urlpatterns", urls) + + # Test missing fields + response = client.post("/", data={}) + assert response.status_code == 400 + assert response.json() == { + "errors": ["num: Field required", "d: Field required"], + "field_errors": { + "num": {"code": "missing", "message": "Field required"}, + "d": {"code": "missing", "message": "Field required"}, + }, + } + + response = client.post("/", data={"num": 3, "d": "2022-01-01"}) + assert response.status_code == 200 + assert response.json() == {"num": 3, "d": "2022-01-01"} + + # Check that field errors propagate + response = client.post("/", data={"num": "x", "d": "2022-01-01"}) + assert response.status_code == 400 + assert response.json() == { + "errors": [ + "num: Input should be a valid integer, " + "unable to parse string as an integer" + ], + "field_errors": { + "num": { + "code": "int_parsing", + "message": "Input should be a valid integer, " + "unable to parse string as an integer", + } + }, + } + + response = client.post("/", data={"num": 1, "d": "2022-31-41"}) + assert response.status_code == 400 + assert response.json() == { + "errors": [ + "d: Input should be a valid date or datetime, " + "month value is outside expected range of 1-12" + ], + "field_errors": { + "d": { + "code": "date_from_datetime_parsing", + "message": ( + "Input should be a valid date or datetime, " + "month value is outside expected range of 1-12" + ), + } + }, + } + + +@override_settings(ROOT_URLCONF=__name__) +def test_parsing_form_encoded_list(client: Client, mocker: MockerFixture) -> None: + class Body(BaseModel): + numbers: list[int] + + @api( + method="POST", + login_required=False, + ) + def view(request: HttpRequest, body: Body) -> JsonResponse: + return JsonResponse(body.model_dump(mode="json")) + + urls = [path("", view)] + mocker.patch(f"{__name__}.urlpatterns", urls) + + response = client.post("/", data={"numbers": [3, 4]}) + assert response.status_code == 200 + assert response.json() == {"numbers": [3, 4]} + + # Check that field errors propagate + response = client.post("/", data={"numbers": "hello"}) + assert response.status_code == 400 + assert response.json() == { + "errors": [ + "numbers.0: Input should be a valid integer, " + "unable to parse string as an integer" + ], + "field_errors": { + "numbers.0": { + "code": "int_parsing", + "message": ( + "Input should be a valid integer, " + "unable to parse string as an integer" + ), + } + }, + } + + @override_settings(ROOT_URLCONF=__name__) def test_parsing_list(client: Client, mocker: MockerFixture) -> None: class Body(BaseModel): @@ -307,10 +400,7 @@ class Body(BaseModel): def view(request: HttpRequest, body: list[Body]) -> JsonResponse: return JsonResponse({}) - urls = [ - path("", view), - ] - + urls = [path("", view)] mocker.patch(f"{__name__}.urlpatterns", urls) assert client.post("/", data={}, content_type="application/json").status_code == 400