diff --git a/funnel/views/helpers.py b/funnel/views/helpers.py index d003cc87f..e4185e9f9 100644 --- a/funnel/views/helpers.py +++ b/funnel/views/helpers.py @@ -6,13 +6,12 @@ import zlib import zoneinfo from base64 import urlsafe_b64encode -from collections.abc import Callable +from collections.abc import Callable, Mapping from contextlib import nullcontext from datetime import datetime, timedelta from hashlib import blake2b from importlib import resources from os import urandom -from typing import Any from urllib.parse import quote, unquote, urljoin, urlsplit import brotli @@ -93,14 +92,36 @@ def __delitem__(self, key: str) -> None: self.keys_at.remove(f'{key}_at') super().__delitem__(key) - def has_intersection(self, other: Any) -> bool: + def has_overlap_with(self, other: Mapping) -> bool: """Check for intersection with other dictionary-like object.""" okeys = other.keys() return not (self.keys_at.isdisjoint(okeys) and self.keys().isdisjoint(okeys)) + def crosscheck_session(self, response: ResponseType) -> ResponseType: + """Add timestamps to timed values in session, and remove expired values.""" + # Process timestamps only if there is at least one match. Most requests will + # have no match. + if self.has_overlap_with(session): + now = utcnow() + for var, delta in self.items(): + var_at = f'{var}_at' + if var in session: + if var_at not in session: + # Session has var but not timestamp, so add a timestamp + session[var_at] = now + elif session[var_at] < now - delta: + # Session var has expired, so remove var and timestamp + session.pop(var) + session.pop(var_at) + elif var_at in session: + # Timestamp present without var, so remove it + session.pop(var_at) + return response + #: Temporary values that must be periodically expunged from the cookie session session_timeouts = SessionTimeouts() +app.after_request(session_timeouts.crosscheck_session) # --- Utilities ------------------------------------------------------------------------ @@ -626,29 +647,6 @@ def commit_db_session(response: ResponseType) -> ResponseType: return response -@app.after_request -def track_temporary_session_vars(response: ResponseType) -> ResponseType: - """Add timestamps to timed values in session, and remove expired values.""" - # Process timestamps only if there is at least one match. Most requests will - # have no match. - if session_timeouts.has_intersection(session): - for var, delta in session_timeouts.items(): - var_at = f'{var}_at' - if var in session: - if var_at not in session: - # Session has var but not timestamp, so add a timestamp - session[var_at] = utcnow() - elif session[var_at] < utcnow() - delta: - # Session var has expired, so remove var and timestamp - session.pop(var) - session.pop(var_at) - elif var_at in session: - # Timestamp present without var, so remove it - session.pop(var_at) - - return response - - @app.after_request def cache_expiry_headers(response: ResponseType) -> ResponseType: if response.expires is None: diff --git a/tests/unit/views/session_temp_vars_test.py b/tests/unit/views/session_temp_vars_test.py index 2ceed93ea..0963bec9f 100644 --- a/tests/unit/views/session_temp_vars_test.py +++ b/tests/unit/views/session_temp_vars_test.py @@ -47,8 +47,8 @@ def test_session_intersection() -> None: fake_session_intersection = {'test': 'value', 'other': 'other_value'} fake_session_disjoint = {'other': 'other_value', 'yet_other': 'yet_other_value'} - assert st.has_intersection(fake_session_intersection) - assert not st.has_intersection(fake_session_disjoint) + assert st.has_overlap_with(fake_session_intersection) + assert not st.has_overlap_with(fake_session_disjoint) @pytest.fixture()