Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ http exceptions and default exception handlers #9

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
".coverage": true,

// 🗑️
".patches": true,
// ".patches": true,
".venv": true,
".vscode": true,
".ruff_cache": true,
Expand Down
9 changes: 3 additions & 6 deletions example/app/data/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from typing import List

from pest.utils.functions import dump_model

from ..modules.todo.models.todo import TodoCreate, TodoModel

default_todos = [
Expand Down Expand Up @@ -38,12 +40,7 @@ def next_id() -> int:
return max([todo.id for todo in self.todos]) + 1

new_todo = TodoModel(
**(
# HACK: model_dump was added in pydantic@2 and will replace dict()
# from @3x onwards; this is a temporary hack to support both
# versions until pydantic@2x is widely adopted
todo.model_dump() if hasattr(todo, 'model_dump') else todo.dict()
),
**(dump_model(todo)),
id=next_id(),
done=False,
)
Expand Down
10 changes: 10 additions & 0 deletions example/app/modules/todo/services/todo_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import fastapi

from pest.exceptions.http.http import ForbiddenException

from ....data.data import TodoRepo
from ..models.todo import TodoCreate, TodoModel

Expand All @@ -20,6 +22,14 @@ def get_all(self) -> list[TodoModel]:
return self.repo.get_all()

def get(self, id: int) -> TodoModel:
if (id == 10):
# nestjs style http exception raising
raise ForbiddenException('Number 10 is reserved for Lionel Messi only')

if (id == 11):
# to test default exception handler
raise Exception('This is a test exception')

return validate(self.repo.get_by_id(id))

def create(self, todo: TodoCreate) -> TodoModel:
Expand Down
22 changes: 21 additions & 1 deletion pest/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@

from fastapi import FastAPI, Response, routing
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.params import Depends
from fastapi.responses import JSONResponse
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import generate_unique_id
from pydantic import ValidationError
from rodi import ActivationScope
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.routing import BaseRoute
from typing_extensions import Doc

from pest.logging import log
from pest.middleware.types import CorsOptions

from ..exceptions import handle
from ..metadata.types.module_meta import InjectionToken
from ..middleware.base import (
PestBaseHTTPMiddleware,
Expand Down Expand Up @@ -57,7 +61,23 @@ def __init__(
for middleware in middleware
]
)
print('ok')

self.add_exception_handlers([
(HTTPException, handle.http),
(ValidationError, handle.request_validation),
(RequestValidationError, handle.request_validation),
(WebSocketRequestValidationError, handle.websocket_request_validation),

# for everything else, there's Mastercard (or was it Bancard? 🤔)
(Exception, handle.the_rest),
])

def add_exception_handlers(
self,
handlers: list[tuple[int | type[Exception], Callable]]
) -> None: # pragma: no cover
for error, handler in handlers:
self.add_exception_handler(error, handler)

def __str__(self) -> str:
return str(root_module(self))
Expand Down
2 changes: 1 addition & 1 deletion pest/core/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pest.metadata.types._meta import PestType

from ...utils.exceptions.base import PestException
from ...exceptions.base.pest import PestException
from ..types.status import Status


Expand Down
2 changes: 1 addition & 1 deletion pest/core/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from fastapi.routing import APIRoute

from ..exceptions.base.pest import PestException
from ..metadata.meta import get_meta, get_meta_value, inject_metadata
from ..metadata.types._meta import PestType
from ..metadata.types.controller_meta import ControllerMeta
from ..metadata.types.handler_meta import HandlerMeta
from ..utils.exceptions.base import PestException
from ..utils.fastapi.router import PestRouter
from .common import PestPrimitive
from .handler import HandlerTuple, setup_handler
Expand Down
2 changes: 1 addition & 1 deletion pest/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pest.utils.module as module_utils
from pest.metadata.types._meta import PestType

from ..exceptions.base.pest import PestException
from ..metadata.meta import get_meta
from ..metadata.types.module_meta import (
ClassProvider,
Expand All @@ -16,7 +17,6 @@
Provider,
ValueProvider,
)
from ..utils.exceptions.base import PestException
from .common import PestPrimitive
from .controller import Controller, router_of, setup_controller
from .types.status import Status
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@

import re

from ..colorize import c
from ...utils.colorize import c


class PestException(Exception):
"""Base class for all exceptions raised by Pest."""
"""Base class for all exceptions raised by Pest during setup

This exception is not meant to be used outside of pest's internals, since it's not
handled by any of the asgi exception handlers.
"""

def __init__(
self,
Expand Down
78 changes: 78 additions & 0 deletions pest/exceptions/handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code
from fastapi.websockets import WebSocket
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import WS_1008_POLICY_VIOLATION

from ..utils.functions import dump_model
from .http.http import ExceptionResponse, PestHTTPException
from .http.status import HTTPStatusEnum, http_status


async def http(request: Request, exc: HTTPException) -> Response:
"""handles both `HTTPException` (fastapi) and `PestHTTPException` (pest)"""

headers = getattr(exc, 'headers', None)
if not is_body_allowed_for_status_code(exc.status_code):
return Response(status_code=exc.status_code, headers=headers)

if isinstance(exc, PestHTTPException):
content = vars(exc)
else:
content = vars(
PestHTTPException(
status_code=exc.status_code,
detail=exc.detail,
headers=headers,
)
)

return JSONResponse(
status_code=exc.status_code,
content=content,
headers=headers,
)


async def request_validation(
request: Request, exc: ValidationError | RequestValidationError
) -> JSONResponse:
"""handles both `RequestValidationError` (fastapi) and `ValidationError` (pydantic)"""

stat = http_status(HTTPStatusEnum.HTTP_400_BAD_REQUEST)

messages = [
f"{', '.join([str(elem) for elem in err['loc']])}: {err['msg'].capitalize()}"
for err in exc.errors()
]

return JSONResponse(
status_code=stat.code,
content=vars(ExceptionResponse(code=stat.code, error=stat.phrase, message=messages)),
)


async def websocket_request_validation(
websocket: WebSocket, exc: WebSocketRequestValidationError
) -> None:
await websocket.close(code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors()))


async def the_rest(request: Request, exc: Exception) -> Response:
"""handles all other exceptions

returns a generic 500 response and logs the exception message, just in case it
was an
"""

stat = http_status(HTTPStatusEnum.HTTP_500_INTERNAL_SERVER_ERROR)

return JSONResponse(
status_code=stat.code, content=dump_model(
ExceptionResponse(code=stat.code, error=stat.phrase)
)
)
Loading