diff --git a/django_structlog/middlewares/request.py b/django_structlog/middlewares/request.py index e0ac48d1..59e894a5 100644 --- a/django_structlog/middlewares/request.py +++ b/django_structlog/middlewares/request.py @@ -21,6 +21,7 @@ from django.core.exceptions import PermissionDenied from django.core.signals import got_request_exception from django.http import Http404, StreamingHttpResponse +from django.utils.functional import SimpleLazyObject from .. import signals from ..app_settings import app_settings @@ -39,6 +40,7 @@ if TYPE_CHECKING: # pragma: no cover from types import TracebackType + from django.contrib.auth.base_user import AbstractBaseUser from django.http import HttpRequest, HttpResponse logger = structlog.getLogger(__name__) @@ -209,7 +211,14 @@ def format_request(request: "HttpRequest") -> str: @staticmethod def bind_user_id(request: "HttpRequest") -> None: user_id_field = app_settings.USER_ID_FIELD - if hasattr(request, "user") and request.user is not None and user_id_field: + if not user_id_field or not hasattr(request, "user"): + return + + session_was_accessed = ( + request.session.accessed if hasattr(request, "session") else None + ) + + if request.user is not None: user_id = None if hasattr(request.user, user_id_field): user_id = getattr(request.user, user_id_field) @@ -217,6 +226,17 @@ def bind_user_id(request: "HttpRequest") -> None: user_id = str(user_id) structlog.contextvars.bind_contextvars(user_id=user_id) + if session_was_accessed is False: + """using SessionMiddleware but user was never accessed, must reset accessed state""" + user = request.user + + def get_user() -> Any: + request.session.accessed = True + return user + + request.user = cast("AbstractBaseUser", SimpleLazyObject(get_user)) + request.session.accessed = False + def process_got_request_exception( self, sender: Type[Any], request: "HttpRequest", **kwargs: Any ) -> None: diff --git a/test_app/tests/middlewares/test_request.py b/test_app/tests/middlewares/test_request.py index d8088b68..d7cfc0b9 100644 --- a/test_app/tests/middlewares/test_request.py +++ b/test_app/tests/middlewares/test_request.py @@ -7,7 +7,9 @@ from unittest.mock import AsyncMock, Mock, patch import structlog +from django.contrib.auth.middleware import AuthenticationMiddleware from django.contrib.auth.models import AnonymousUser, User +from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sites.models import Site from django.contrib.sites.shortcuts import get_current_site from django.core.exceptions import PermissionDenied @@ -236,6 +238,97 @@ class SimpleUser: self.assertIn("user_id", record.msg) self.assertIsNone(record.msg["user_id"]) + @override_settings( + SECRET_KEY="00000000000000000000000000000000", + ) + def test_process_request_session_middleware_without_vary(self) -> None: + def get_response(_request: HttpRequest) -> HttpResponse: + with self.assertLogs(__name__, logging.INFO) as log_results: + self.logger.info("hello") + self.log_results = log_results + return HttpResponse() + + request = self.factory.get("/foo") + + # simulate SessionMiddleware, AuthenticationMiddleware, and RequestMiddleware called in that order + request_middleware = RequestMiddleware(get_response) + authentication_middleware = AuthenticationMiddleware( + cast( + Any, + lambda r: request_middleware(r), + ) + ) + session_middleware = SessionMiddleware( + cast(Any, lambda r: authentication_middleware(r)) + ) + response = session_middleware(request) + + self.assertEqual(1, len(self.log_results.records)) + record = self.log_results.records[0] + self.assertIsNone(cast(HttpResponse, response).headers.get("Vary")) + + self.assertEqual("INFO", record.levelname) + + self.assertIn("user_id", record.msg) + self.assertIsNone(record.msg["user_id"]) + + @override_settings( + SECRET_KEY="00000000000000000000000000000000", + ) + def test_process_request_session_middleware_with_vary(self) -> None: + def get_response(_request: HttpRequest) -> HttpResponse: + assert isinstance( + request.user, AnonymousUser + ) # force evaluate user to trigger session middleware + with self.assertLogs(__name__, logging.INFO) as log_results: + self.logger.info("hello") + self.log_results = log_results + return HttpResponse() + + request = self.factory.get("/foo") + + # simulate SessionMiddleware, AuthenticationMiddleware, and RequestMiddleware called in that order + request_middleware = RequestMiddleware(get_response) + authentication_middleware = AuthenticationMiddleware( + cast(Any, lambda r: request_middleware(r)) + ) + session_middleware = SessionMiddleware( + cast(Any, lambda r: authentication_middleware(r)) + ) + response = session_middleware(request) + + self.assertEqual(1, len(self.log_results.records)) + record = self.log_results.records[0] + self.assertIsNotNone(cast(HttpResponse, response).headers.get("Vary")) + + self.assertEqual("INFO", record.levelname) + + self.assertIn("user_id", record.msg) + self.assertIsNone(record.msg["user_id"]) + + @override_settings( + DJANGO_STRUCTLOG_USER_ID_FIELD=None, + ) + def test_process_request_no_user_id_field(self) -> None: + def get_response(_request: HttpRequest) -> HttpResponse: + with self.assertLogs(__name__, logging.INFO) as log_results: + self.logger.info("hello") + self.log_results = log_results + return HttpResponse() + + request = self.factory.get("/foo") + + middleware = RequestMiddleware(get_response) + response = middleware(request) + self.assertEqual(200, cast(HttpResponse, response).status_code) + + self.assertEqual(1, len(self.log_results.records)) + record = self.log_results.records[0] + + self.assertEqual("INFO", record.levelname) + + self.assertNotIn("user_id", record.msg) + def test_log_user_in_request_finished(self) -> None: mock_response = Mock() mock_response.status_code = 200