From 89f2645edfdec01ae7b12cea2ec6163405803d73 Mon Sep 17 00:00:00 2001 From: Kiran Jonnalagadda Date: Fri, 8 Dec 2023 10:51:49 +0530 Subject: [PATCH] Fix incorrect attr for user in LoginSession (missed in #1697) (#1938) Also fix/update imports and typing. --- funnel/models/login_session.py | 4 ++++ funnel/views/login_session.py | 2 +- tests/conftest.py | 2 +- tests/unit/utils/markdown/conftest.py | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/funnel/models/login_session.py b/funnel/models/login_session.py index 73c8bee53..384fbd872 100644 --- a/funnel/models/login_session.py +++ b/funnel/models/login_session.py @@ -34,6 +34,10 @@ class LoginSessionError(Exception): """Base exception for user session errors.""" + def __init__(self, login_session: LoginSession, *args) -> None: + self.login_session = login_session + super().__init__(login_session, *args) + class LoginSessionExpiredError(LoginSessionError): """This user session has expired and cannot be marked as currently active.""" diff --git a/funnel/views/login_session.py b/funnel/views/login_session.py index 7ab5f3086..294c1be59 100644 --- a/funnel/views/login_session.py +++ b/funnel/views/login_session.py @@ -174,7 +174,7 @@ def _load_user(): # TODO: Force render of logout page to clear client-side data logout_internal() except LoginSessionInactiveUserError as exc: - inactive_user = exc.args[0].user + inactive_user = exc.login_session.account if inactive_user.state.SUSPENDED: flash(_("Your account has been suspended")) elif inactive_user.state.DELETED: diff --git a/tests/conftest.py b/tests/conftest.py index 2c4948f5d..2444960a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,7 +90,7 @@ def sort_key(item: pytest.Function) -> tuple[int, str]: # as item.location == (file_path, line_no, function_name). However, pytest-bdd # reports itself for file_path, so we can't use that and must extract the path # from the test module instead - module_file = item.module.__file__ + module_file = item.module.__file__ if item.module is not None else '' for counter, path in enumerate(test_order): if path in module_file: return (counter, module_file) diff --git a/tests/unit/utils/markdown/conftest.py b/tests/unit/utils/markdown/conftest.py index c3dda32bb..feb8569e8 100644 --- a/tests/unit/utils/markdown/conftest.py +++ b/tests/unit/utils/markdown/conftest.py @@ -107,7 +107,7 @@ def dump(cls) -> None: if cls.test_map is not None: for md_testname, data in cls.test_files.items(): data['expected_output'] = { - md_configname: tomlkit.api.string(case.output, multiline=True) + md_configname: tomlkit.string(case.output, multiline=True) for md_configname, case in cls.test_map[md_testname].items() } (md_tests_data_root / md_testname).write_text(tomlkit.dumps(data))