Skip to content

Commit

Permalink
Fix issue where CORS headers were not being returned properly
Browse files Browse the repository at this point in the history
  • Loading branch information
suecharo committed Oct 8, 2024
1 parent 8c8d672 commit 41da3a9
Showing 1 changed file with 50 additions and 2 deletions.
52 changes: 50 additions & 2 deletions sapporo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.datastructures import Headers, MutableHeaders
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from sapporo.auth import get_auth_config
from sapporo.config import (LOGGER, PKG_DIR, add_openapi_info, get_config,
Expand Down Expand Up @@ -59,6 +62,52 @@ async def generic_exception_handler(_request: Request, _exc: Exception) -> JSONR
)


class CustomCORSMiddleware(CORSMiddleware):
"""\
CORSMiddleware that returns CORS headers even if the Origin header is not present
"""

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

method = scope["method"]
headers = Headers(scope=scope)

if method == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return

await self.simple_response(scope, receive, send, request_headers=headers)

async def send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
if message["type"] != "http.response.start":
await send(message)
return

message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
origin = request_headers.get("Origin", "*")
has_cookie = "cookie" in request_headers

# If request includes any cookie headers, then we must respond
# with the specific origin instead of '*'.
if self.allow_all_origins and has_cookie:
self.allow_explicit_origin(headers, origin)

# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
self.allow_explicit_origin(headers, origin)

await send(message)


def init_app_state() -> None:
"""
Perform validation, initialize the cache, and log the configuration contents.
Expand Down Expand Up @@ -146,9 +195,8 @@ def create_app() -> FastAPI:
)

app.add_middleware(
CORSMiddleware,
CustomCORSMiddleware,
allow_origins=[app_config.allow_origin],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
Expand Down

0 comments on commit 41da3a9

Please sign in to comment.