diff --git a/pyproject.toml b/pyproject.toml index a476ae3489..0ffaa6852f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ module = [ "palace.manager.api.metadata.*", "palace.manager.api.odl.*", "palace.manager.api.opds_for_distributors", + "palace.manager.api.util.flask", "palace.manager.core.opds2_import", "palace.manager.core.opds_import", "palace.manager.core.opds_schema", diff --git a/src/palace/manager/api/admin/controller/admin_search.py b/src/palace/manager/api/admin/controller/admin_search.py index 295f12d72a..9ab0a2f96a 100644 --- a/src/palace/manager/api/admin/controller/admin_search.py +++ b/src/palace/manager/api/admin/controller/admin_search.py @@ -1,9 +1,9 @@ from __future__ import annotations -import flask from sqlalchemy import func, or_ from palace.manager.api.admin.controller.base import AdminController +from palace.manager.api.util.flask import get_request_library from palace.manager.sqlalchemy.model.classification import ( Classification, Genre, @@ -30,7 +30,7 @@ def search_field_values(self) -> dict: - Publisher - Subject """ - library: Library = flask.request.library # type: ignore + library = get_request_library() collection_ids = [coll.id for coll in library.collections if coll.id] return self._search_field_values_cached(collection_ids) diff --git a/src/palace/manager/api/admin/controller/custom_lists.py b/src/palace/manager/api/admin/controller/custom_lists.py index 04de13dc93..4237432a90 100644 --- a/src/palace/manager/api/admin/controller/custom_lists.py +++ b/src/palace/manager/api/admin/controller/custom_lists.py @@ -23,6 +23,7 @@ CirculationManagerController, ) from palace.manager.api.problem_details import CANNOT_DELETE_SHARED_LIST +from palace.manager.api.util.flask import get_request_library from palace.manager.core.app_server import load_pagination_from_request from palace.manager.core.problem_details import INVALID_INPUT, METHOD_NOT_ALLOWED from palace.manager.core.query.customlist import CustomListQueries @@ -83,7 +84,7 @@ def _list_as_json(self, list: CustomList, is_owner=True) -> dict: ) def custom_lists(self) -> dict | ProblemDetail | Response | None: - library: Library = flask.request.library # type: ignore # "Request" has no attribute "library" + library = get_request_library() self.require_librarian(library) if flask.request.method == "GET": @@ -322,7 +323,7 @@ def url_fn(after): return url_fn def custom_list(self, list_id: int) -> Response | dict | ProblemDetail | None: - library: Library = flask.request.library # type: ignore + library = get_request_library() self.require_librarian(library) data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF) @@ -375,7 +376,7 @@ def custom_list(self, list_id: int) -> Response | dict | ProblemDetail | None: elif flask.request.method == "DELETE": # Deleting requires a library manager. - self.require_library_manager(flask.request.library) # type: ignore + self.require_library_manager(get_request_library()) if len(list.shared_locally_with_libraries) > 0: return CANNOT_DELETE_SHARED_LIST @@ -413,7 +414,7 @@ def share_locally( customlist = get_one(self._db, CustomList, id=customlist_id) if not customlist: return MISSING_CUSTOM_LIST - if customlist.library != flask.request.library: # type: ignore + if customlist.library != get_request_library(): return ADMIN_NOT_AUTHORIZED.detailed( _("This library does not have permissions on this customlist.") ) diff --git a/src/palace/manager/api/admin/controller/dashboard.py b/src/palace/manager/api/admin/controller/dashboard.py index 4624ceaf21..5115e6b47b 100644 --- a/src/palace/manager/api/admin/controller/dashboard.py +++ b/src/palace/manager/api/admin/controller/dashboard.py @@ -12,6 +12,7 @@ CirculationManagerController, ) from palace.manager.api.local_analytics_exporter import LocalAnalyticsExporter +from palace.manager.api.util.flask import get_request_library from palace.manager.feed.annotator.admin import AdminAnnotator from palace.manager.sqlalchemy.model.admin import Admin from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent @@ -29,7 +30,7 @@ def stats( return stats_function(admin, self._db) def circulation_events(self): - annotator = AdminAnnotator(self.circulation, flask.request.library) + annotator = AdminAnnotator(self.circulation, get_request_library()) num = min(int(flask.request.args.get("num", "100")), 500) results = ( @@ -92,7 +93,7 @@ def get_date(field): date_end_label = get_date("dateEnd") date_end = date_end_label + timedelta(days=1) locations = flask.request.args.get("locations", None) - library = getattr(flask.request, "library", None) + library = get_request_library(default=None) library_short_name = library.short_name if library else None analytics_exporter = analytics_exporter or LocalAnalyticsExporter() diff --git a/src/palace/manager/api/admin/controller/lanes.py b/src/palace/manager/api/admin/controller/lanes.py index 6df13b1a6d..eafe7a906e 100644 --- a/src/palace/manager/api/admin/controller/lanes.py +++ b/src/palace/manager/api/admin/controller/lanes.py @@ -20,6 +20,7 @@ CirculationManagerController, ) from palace.manager.api.lanes import create_default_lanes +from palace.manager.api.util.flask import get_request_library from palace.manager.sqlalchemy.model.customlist import CustomList from palace.manager.sqlalchemy.model.lane import Lane from palace.manager.sqlalchemy.model.library import Library @@ -28,7 +29,7 @@ class LanesController(CirculationManagerController, AdminPermissionsControllerMixin): def lanes(self): - library = flask.request.library + library = get_request_library() self.require_librarian(library) if flask.request.method == "GET": @@ -56,7 +57,7 @@ def lanes_for_parent(parent): return dict(lanes=lanes_for_parent(None)) if flask.request.method == "POST": - self.require_library_manager(flask.request.library) + self.require_library_manager(get_request_library()) id = flask.request.form.get("id") parent_id = flask.request.form.get("parent_id") @@ -173,7 +174,7 @@ def lanes_for_parent(parent): def lane(self, lane_identifier): if flask.request.method == "DELETE": - library = flask.request.library + library = get_request_library() self.require_library_manager(library) lane = get_one(self._db, Lane, id=lane_identifier, library=library) @@ -192,7 +193,7 @@ def delete_lane_and_sublanes(lane): return Response(str(_("Deleted")), 200) def show_lane(self, lane_identifier): - library = flask.request.library + library = get_request_library() self.require_library_manager(library) lane = get_one(self._db, Lane, id=lane_identifier, library=library) @@ -204,7 +205,7 @@ def show_lane(self, lane_identifier): return Response(str(_("Success")), 200) def hide_lane(self, lane_identifier): - library = flask.request.library + library = get_request_library() self.require_library_manager(library) lane = get_one(self._db, Lane, id=lane_identifier, library=library) @@ -214,13 +215,14 @@ def hide_lane(self, lane_identifier): return Response(str(_("Success")), 200) def reset(self): - self.require_library_manager(flask.request.library) + library = get_request_library() + self.require_library_manager(library) - create_default_lanes(self._db, flask.request.library) + create_default_lanes(self._db, library) return Response(str(_("Success")), 200) def change_order(self): - self.require_library_manager(flask.request.library) + self.require_library_manager(get_request_library()) submitted_lanes = json.loads(flask.request.data) diff --git a/src/palace/manager/api/admin/controller/work_editor.py b/src/palace/manager/api/admin/controller/work_editor.py index 834d9d1283..1ef68f412e 100644 --- a/src/palace/manager/api/admin/controller/work_editor.py +++ b/src/palace/manager/api/admin/controller/work_editor.py @@ -27,6 +27,7 @@ LIBRARY_NOT_FOUND, REMOTE_INTEGRATION_FAILED, ) +from palace.manager.api.util.flask import get_request_library from palace.manager.core.classifier import NO_NUMBER, NO_VALUE, genres from palace.manager.core.classifier.simplified import SimplifiedGenreClassifier from palace.manager.feed.acquisition import OPDSAcquisitionFeed @@ -62,13 +63,14 @@ def details(self, identifier_type, identifier): :return: An OPDSEntryResponse """ - self.require_librarian(flask.request.library) + library = get_request_library() + self.require_librarian(library) - work = self.load_work(flask.request.library, identifier_type, identifier) + work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work - annotator = AdminAnnotator(self.circulation, flask.request.library) + annotator = AdminAnnotator(self.circulation, library) # single_entry returns an OPDSEntryResponse that will not be # cached, which is perfect. We want the admin interface @@ -136,7 +138,8 @@ def rights_status(self): def edit(self, identifier_type, identifier): """Edit a work's metadata.""" - self.require_librarian(flask.request.library) + library = get_request_library() + self.require_librarian(library) # TODO: It would be nice to use the metadata layer for this, but # this code handles empty values differently than other metadata @@ -145,7 +148,7 @@ def edit(self, identifier_type, identifier): # db so that it can overrule other data sources that set a value, # unlike other sources which set empty fields to None. - work = self.load_work(flask.request.library, identifier_type, identifier) + work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work @@ -372,7 +375,7 @@ def suppress( ) -> Response | ProblemDetail: """Suppress a book at the level of a library.""" - library: Library | None = getattr(flask.request, "library") + library: Library | None = get_request_library(default=None) if library is None: raise ProblemDetailException(LIBRARY_NOT_FOUND) @@ -404,7 +407,7 @@ def unsuppress( ) -> Response | ProblemDetail: """Remove a book suppression from a book at the level of a library""" - library: Library | None = getattr(flask.request, "library") + library: Library | None = get_request_library(default=None) if library is None: raise ProblemDetailException(LIBRARY_NOT_FOUND) @@ -433,9 +436,10 @@ def unsuppress( def refresh_metadata(self, identifier_type, identifier, provider=None): """Refresh the metadata for a book from the content server""" - self.require_librarian(flask.request.library) + library = get_request_library() + self.require_librarian(library) - work = self.load_work(flask.request.library, identifier_type, identifier) + work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work @@ -464,9 +468,10 @@ def refresh_metadata(self, identifier_type, identifier, provider=None): def classifications(self, identifier_type, identifier): """Return list of this work's classifications.""" - self.require_librarian(flask.request.library) + library = get_request_library() + self.require_librarian(library) - work = self.load_work(flask.request.library, identifier_type, identifier) + work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work @@ -502,9 +507,10 @@ def classifications(self, identifier_type, identifier): def edit_classifications(self, identifier_type, identifier): """Edit a work's audience, target age, fiction status, and genres.""" - self.require_librarian(flask.request.library) + library = get_request_library() + self.require_librarian(library) - work = self.load_work(flask.request.library, identifier_type, identifier) + work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work @@ -670,9 +676,8 @@ def edit_classifications(self, identifier_type, identifier): return Response("", 200) def custom_lists(self, identifier_type, identifier): - self.require_librarian(flask.request.library) - - library = flask.request.library + library = get_request_library() + self.require_librarian(library) work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work diff --git a/src/palace/manager/api/authenticator.py b/src/palace/manager/api/authenticator.py index 22d46e0113..6737a4d38d 100644 --- a/src/palace/manager/api/authenticator.py +++ b/src/palace/manager/api/authenticator.py @@ -33,6 +33,7 @@ UNKNOWN_SAML_PROVIDER, UNSUPPORTED_AUTHENTICATION_MECHANISM, ) +from palace.manager.api.util.flask import get_request_library from palace.manager.core.user_profile import ProfileController from palace.manager.integration.goals import Goals from palace.manager.service.analytics.analytics import Analytics @@ -129,7 +130,7 @@ def __init__( @property def current_library_short_name(self): - return flask.request.library.short_name + return get_request_library().short_name def populate_authenticators( self, _db, libraries: Iterable[Library], analytics: Analytics | None diff --git a/src/palace/manager/api/circulation.py b/src/palace/manager/api/circulation.py index 21b621ab5c..bd1a8a07a5 100644 --- a/src/palace/manager/api/circulation.py +++ b/src/palace/manager/api/circulation.py @@ -35,6 +35,7 @@ PatronHoldLimitReached, PatronLoanLimitReached, ) +from palace.manager.api.util.flask import get_request_library from palace.manager.api.util.patron import PatronUtility from palace.manager.core.exceptions import PalaceValueError from palace.manager.integration.base import HasLibraryIntegrationConfiguration @@ -948,12 +949,12 @@ def _collect_event( if patron: # The library of the patron who caused the event. library = patron.library - elif flask.request and getattr(flask.request, "library", None): - # The library associated with the current request. - library = getattr(flask.request, "library") else: - # The library associated with the CirculationAPI itself. - library = self.library + # The library associated with the current request, defaulting to + # the library associated with the CirculationAPI itself if we are + # outside a request context, or if the request context does not + # have a library associated with it. + library = get_request_library(default=self.library) neighborhood = None if ( diff --git a/src/palace/manager/api/circulation_manager.py b/src/palace/manager/api/circulation_manager.py index 95c9074925..1a2c781eee 100644 --- a/src/palace/manager/api/circulation_manager.py +++ b/src/palace/manager/api/circulation_manager.py @@ -32,6 +32,7 @@ from palace.manager.api.lanes import load_lanes from palace.manager.api.problem_details import NO_SUCH_LANE from palace.manager.api.saml.controller import SAMLController +from palace.manager.api.util.flask import get_request_library from palace.manager.core.app_server import ( ApplicationVersionController, load_facets_from_request, @@ -392,7 +393,7 @@ def annotator(self, lane, facets=None, *args, **kwargs): elif lane and isinstance(lane, WorkList): library = lane.get_library(self._db) if not library and hasattr(flask.request, "library"): - library = flask.request.library + library = get_request_library() # If no library is provided, the best we can do is a generic # annotator for this application. @@ -434,7 +435,7 @@ def authentication_for_opds_document(self): internal details of deployment, it should only be enabled when diagnosing deployment problems. """ - name = flask.request.library.short_name + name = get_request_library().short_name value = self.authentication_for_opds_documents.get(name, None) if value is None: # The document was not in the cache, either because it's diff --git a/src/palace/manager/api/controller/analytics.py b/src/palace/manager/api/controller/analytics.py index 266ac153a0..ba1022a7d2 100644 --- a/src/palace/manager/api/controller/analytics.py +++ b/src/palace/manager/api/controller/analytics.py @@ -7,6 +7,7 @@ CirculationManagerController, ) from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE +from palace.manager.api.util.flask import get_request_library from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.problem_detail import ProblemDetail @@ -18,7 +19,7 @@ def track_event(self, identifier_type, identifier, event_type): # a way to distinguish between different LicensePools for the # same book. if event_type in CirculationEvent.CLIENT_EVENTS: - library = flask.request.library + library = get_request_library() # Authentication on the AnalyticsController is optional, # so flask.request.patron may or may not be set. patron = getattr(flask.request, "patron", None) diff --git a/src/palace/manager/api/controller/base.py b/src/palace/manager/api/controller/base.py index 951bc7e3d4..d124380795 100644 --- a/src/palace/manager/api/controller/base.py +++ b/src/palace/manager/api/controller/base.py @@ -123,5 +123,5 @@ def library_for_request( if not library: return LIBRARY_NOT_FOUND - flask.request.library = library # type: ignore[attr-defined] + setattr(flask.request, "library", library) return library diff --git a/src/palace/manager/api/controller/circulation_manager.py b/src/palace/manager/api/controller/circulation_manager.py index 2f69de8042..1a1401ea37 100644 --- a/src/palace/manager/api/controller/circulation_manager.py +++ b/src/palace/manager/api/controller/circulation_manager.py @@ -2,7 +2,6 @@ from typing import TypeVar -import flask from flask_babel import lazy_gettext as _ from sqlalchemy import select from sqlalchemy.orm import Session, eagerload @@ -17,6 +16,7 @@ NOT_AGE_APPROPRIATE, REMOTE_INTEGRATION_FAILED, ) +from palace.manager.api.util.flask import get_request_library from palace.manager.core.problem_details import INVALID_INPUT from palace.manager.search.external_search import ExternalSearchIndex from palace.manager.service.redis.redis import Redis @@ -82,7 +82,7 @@ def get_patron_hold( @property def circulation(self) -> CirculationAPI: """Return the appropriate CirculationAPI for the request Library.""" - library_id = flask.request.library.id # type: ignore[attr-defined] + library_id = get_request_library().id return self.manager.circulation_apis[library_id] # type: ignore[no-any-return] @property @@ -103,7 +103,7 @@ def redis_client(self) -> Redis: def load_lane(self, lane_identifier: int | None) -> Lane | WorkList | ProblemDetail: """Turn user input into a Lane object.""" - library_id = flask.request.library.id # type: ignore[attr-defined] + library_id = get_request_library().id lane = None if lane_identifier is None: diff --git a/src/palace/manager/api/controller/index.py b/src/palace/manager/api/controller/index.py index fb2a64d431..d059562a88 100644 --- a/src/palace/manager/api/controller/index.py +++ b/src/palace/manager/api/controller/index.py @@ -1,11 +1,11 @@ from __future__ import annotations -import flask from flask import Response, redirect, url_for from palace.manager.api.controller.circulation_manager import ( CirculationManagerController, ) +from palace.manager.api.util.flask import get_request_library from palace.manager.util.authentication_for_opds import AuthenticationForOPDSDocument from palace.manager.util.problem_detail import ProblemDetail @@ -15,7 +15,7 @@ class IndexController(CirculationManagerController): def __call__(self): # The simple case: the app is equally open to all clients. - library_short_name = flask.request.library.short_name + library_short_name = get_request_library().short_name if not self.has_root_lanes(): return redirect( url_for( @@ -43,7 +43,7 @@ def has_root_lanes(self): :return: A boolean """ - return flask.request.library.has_root_lanes + return get_request_library().has_root_lanes def authenticated_patron_root_lane(self): patron = self.authenticated_patron_from_request() @@ -54,7 +54,7 @@ def authenticated_patron_root_lane(self): return patron.root_lane def appropriate_index_for_patron_type(self): - library_short_name = flask.request.library.short_name + library_short_name = get_request_library().short_name root_lane = self.authenticated_patron_root_lane() if isinstance(root_lane, ProblemDetail): return root_lane diff --git a/src/palace/manager/api/controller/loan.py b/src/palace/manager/api/controller/loan.py index 3db0353165..75067962b8 100644 --- a/src/palace/manager/api/controller/loan.py +++ b/src/palace/manager/api/controller/loan.py @@ -23,6 +23,7 @@ NO_ACTIVE_LOAN_OR_HOLD, NO_LICENSES, ) +from palace.manager.api.util.flask import get_request_library from palace.manager.celery.tasks.patron_activity import sync_patron_activity from palace.manager.core.problem_details import INTERNAL_SERVER_ERROR from palace.manager.feed.acquisition import OPDSAcquisitionFeed @@ -89,7 +90,7 @@ def borrow( book or the license file. """ patron = flask.request.patron # type: ignore[attr-defined] - library = flask.request.library # type: ignore[attr-defined] + library = get_request_library() header = self.authorization_header() credential = self.manager.auth.get_credential_from_header(header) @@ -300,7 +301,7 @@ def fulfill( # There's still a chance this request can succeed, but if not, # we'll be sending out authentication_response. patron = None - library = flask.request.library # type: ignore + library = get_request_library() header = self.authorization_header() credential = self.manager.auth.get_credential_from_header(header) @@ -492,7 +493,7 @@ def detail( self, identifier_type: str, identifier: str ) -> OPDSEntryResponse | ProblemDetail | None: patron = flask.request.patron # type: ignore[attr-defined] - library = flask.request.library # type: ignore[attr-defined] + library = get_request_library() pools = self.load_licensepools(library, identifier_type, identifier) if isinstance(pools, ProblemDetail): return pools diff --git a/src/palace/manager/api/controller/marc.py b/src/palace/manager/api/controller/marc.py index 3114fbca40..cc61426895 100644 --- a/src/palace/manager/api/controller/marc.py +++ b/src/palace/manager/api/controller/marc.py @@ -4,11 +4,11 @@ from dataclasses import dataclass, field from datetime import datetime -import flask from flask import Response from sqlalchemy import select from sqlalchemy.orm import Session +from palace.manager.api.util.flask import get_request_library from palace.manager.integration.goals import Goals from palace.manager.marc.exporter import MarcExporter from palace.manager.service.integration_registry.catalog_services import ( @@ -60,7 +60,7 @@ def __init__( @staticmethod def library() -> Library: - return flask.request.library # type: ignore[no-any-return,attr-defined] + return get_request_library() def has_integration(self, session: Session, library: Library) -> bool: protocols = self.registry.get_protocols(MarcExporter) diff --git a/src/palace/manager/api/controller/opds_feed.py b/src/palace/manager/api/controller/opds_feed.py index 1bd8908b57..f13f67fa06 100644 --- a/src/palace/manager/api/controller/opds_feed.py +++ b/src/palace/manager/api/controller/opds_feed.py @@ -15,6 +15,7 @@ JackpotWorkList, ) from palace.manager.api.problem_details import NO_SUCH_COLLECTION, NO_SUCH_LIST +from palace.manager.api.util.flask import get_request_library from palace.manager.core.app_server import ( load_facets_from_request, load_pagination_from_request, @@ -45,7 +46,7 @@ def groups(self, lane_identifier, feed_class=OPDSAcquisitionFeed): :param feed_class: A replacement for AcquisitionFeed, for use in tests. """ - library = flask.request.library + library = get_request_library() # Special case: a patron with a root lane who attempts to access # the library's top-level WorkList is redirected to their root @@ -128,7 +129,7 @@ def feed(self, lane_identifier, feed_class=OPDSAcquisitionFeed): if isinstance(search_engine, ProblemDetail): return search_engine - library_short_name = flask.request.library.short_name + library_short_name = get_request_library().short_name url = url_for( "feed", lane_identifier=lane_identifier, @@ -159,7 +160,7 @@ def navigation(self, lane_identifier): lane = self.load_lane(lane_identifier) if isinstance(lane, ProblemDetail): return lane - library = flask.request.library + library = get_request_library() library_short_name = library.short_name url = url_for( "navigation_feed", @@ -191,7 +192,7 @@ def crawlable_library_feed(self): """Build or retrieve a crawlable acquisition feed for the request library. """ - library = flask.request.library + library = get_request_library() url = url_for( "crawlable_library_feed", library_short_name=library.short_name, @@ -224,7 +225,7 @@ def crawlable_list_feed(self, list_name): # TODO: A library is not strictly required here, since some # CustomLists aren't associated with a library, but this isn't # a use case we need to support now. - library = flask.request.library + library = get_request_library() list = CustomList.find(self._db, list_name, library=library) if not list: return NO_SUCH_LIST @@ -285,7 +286,7 @@ def _crawlable_feed( ) def _load_search_facets(self, lane): - entrypoints = list(flask.request.library.entrypoints) + entrypoints = list(get_request_library().entrypoints) if len(entrypoints) > 1: # There is more than one enabled EntryPoint. # By default, search them all. @@ -326,7 +327,7 @@ def search(self, lane_identifier, feed_class=OPDSAcquisitionFeed): # Check whether there is a query string -- if not, we want to # send an OpenSearch document explaining how to search. query = flask.request.args.get("q") - library_short_name = flask.request.library.short_name + library_short_name = get_request_library().short_name # Create a function that, when called, generates a URL to the # search controller. @@ -391,7 +392,7 @@ def _qa_feed( :return: A ProblemDetail if there's a problem loading the faceting object; otherwise the return value of `feed_factory`. """ - library = flask.request.library + library = get_request_library() search_engine = self.search_engine if isinstance(search_engine, ProblemDetail): return search_engine diff --git a/src/palace/manager/api/controller/playtime_entries.py b/src/palace/manager/api/controller/playtime_entries.py index 688c7d7aa5..2fd74fd710 100644 --- a/src/palace/manager/api/controller/playtime_entries.py +++ b/src/palace/manager/api/controller/playtime_entries.py @@ -14,11 +14,11 @@ PlaytimeEntriesPostResponse, ) from palace.manager.api.problem_details import NOT_FOUND_ON_REMOTE +from palace.manager.api.util.flask import get_request_library from palace.manager.core.problem_details import INVALID_INPUT from palace.manager.core.query.playtime_entries import PlaytimeEntries from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.identifier import Identifier -from palace.manager.sqlalchemy.model.library import Library from palace.manager.sqlalchemy.model.licensing import LicensePool from palace.manager.sqlalchemy.model.patron import Loan from palace.manager.sqlalchemy.util import get_one @@ -39,7 +39,7 @@ def sha1(msg): class PlaytimeEntriesController(CirculationManagerController): def track_playtimes(self, collection_id, identifier_type, identifier_idn): - library: Library = flask.request.library + library = get_request_library() identifier = get_one( self._db, Identifier, type=identifier_type, identifier=identifier_idn ) diff --git a/src/palace/manager/api/controller/urn_lookup.py b/src/palace/manager/api/controller/urn_lookup.py index 16c1bd4181..ffeaea6a92 100644 --- a/src/palace/manager/api/controller/urn_lookup.py +++ b/src/palace/manager/api/controller/urn_lookup.py @@ -1,7 +1,6 @@ from __future__ import annotations -import flask - +from palace.manager.api.util.flask import get_request_library from palace.manager.core.app_server import ( URNLookupController as CoreURNLookupController, ) @@ -18,7 +17,7 @@ def work_lookup(self, route_name): top-level WorkList, and use it to generate an OPDS lookup feed. """ - library = flask.request.library + library = get_request_library() top_level_worklist = self.manager.top_level_lanes[library.id] annotator = CirculationManagerAnnotator(top_level_worklist) return super().work_lookup(annotator, route_name) diff --git a/src/palace/manager/api/controller/work.py b/src/palace/manager/api/controller/work.py index 445d57de7c..bf2fb19b48 100644 --- a/src/palace/manager/api/controller/work.py +++ b/src/palace/manager/api/controller/work.py @@ -17,6 +17,7 @@ SeriesLane, ) from palace.manager.api.problem_details import NO_SUCH_LANE, NOT_FOUND_ON_REMOTE +from palace.manager.api.util.flask import get_request_library from palace.manager.core.app_server import load_pagination_from_request from palace.manager.core.config import CannotLoadConfiguration from palace.manager.core.metadata_layer import ContributorData @@ -40,7 +41,7 @@ def contributor( self, contributor_name, languages, audiences, feed_class=OPDSAcquisitionFeed ): """Serve a feed of books written by a particular author""" - library = flask.request.library + library = get_request_library() if not contributor_name: return NO_SUCH_LANE.detailed(_("No contributor provided")) @@ -105,7 +106,7 @@ def permalink(self, identifier_type, identifier): returns a single entry while the /works lookup protocol returns a feed containing any number of entries. """ - library = flask.request.library + library = get_request_library() work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work @@ -146,7 +147,7 @@ def related( ): """Serve a groups feed of books related to a given book.""" - library = flask.request.library + library = get_request_library() work = self.load_work(library, identifier_type, identifier) if work is None: return NOT_FOUND_ON_REMOTE @@ -203,7 +204,7 @@ def recommendations( ): """Serve a feed of recommendations related to a given book.""" - library = flask.request.library + library = get_request_library() work = self.load_work(library, identifier_type, identifier) if isinstance(work, ProblemDetail): return work @@ -255,7 +256,7 @@ def recommendations( def series(self, series_name, languages, audiences, feed_class=OPDSAcquisitionFeed): """Serve a feed of books in a given series.""" - library = flask.request.library + library = get_request_library() if not series_name: return NO_SUCH_LANE.detailed(_("No series provided")) diff --git a/src/palace/manager/api/util/flask.py b/src/palace/manager/api/util/flask.py index 67be659363..f7db64c59b 100644 --- a/src/palace/manager/api/util/flask.py +++ b/src/palace/manager/api/util/flask.py @@ -1,10 +1,110 @@ -from flask import Flask +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload + +import flask from sqlalchemy.orm import Session -from palace.manager.api.circulation_manager import CirculationManager +from palace.manager.core.exceptions import PalaceValueError +from palace.manager.sqlalchemy.model.library import Library +from palace.manager.util.sentinel import SentinelType + +if TYPE_CHECKING: + from palace.manager.api.circulation_manager import CirculationManager + + +TVar = TypeVar("TVar") +TDefault = TypeVar("TDefault") + + +@overload +def get_request_var( + name: str, var_cls: type[TVar], *, default: Literal[SentinelType.NotGiven] = ... +) -> TVar: ... + + +@overload +def get_request_var( + name: str, var_cls: type[TVar], *, default: TDefault +) -> TVar | TDefault: ... + + +def get_request_var( + name: str, + var_cls: type[TVar], + *, + default: TDefault | Literal[SentinelType.NotGiven] = SentinelType.NotGiven, +) -> TVar | TDefault: + """ + Retrieve an attribute from the current Flask request object. + + This helper function handles edge cases such as missing request context or unset attributes. + It ensures type checking and provides type hints for the expected attribute type. + + :param name: The name of the attribute to retrieve. + :param var_cls: The expected type of the attribute. + :param default: The default value to return if the attribute is not set or if there is no request context. + + :return: The attribute from the request object, or the default value if provided. + + :raises PalaceValueError: If the attribute is not set or if the attribute type is incorrect, + and no default is provided. + :raises RuntimeError: If there is no request context and no default is provided. + """ + + if default is not SentinelType.NotGiven and not flask.request: + # We are not in a request context, so we can't get the variable + # if we access it, it will raise an error, so we return the default + return default + + try: + var = getattr(flask.request, name) + except AttributeError: + if default is SentinelType.NotGiven: + raise PalaceValueError(f"No '{name}' set on 'flask.request'") + return default + + if not isinstance(var, var_cls): + if default is SentinelType.NotGiven: + raise PalaceValueError( + f"'{name}' on 'flask.request' has incorrect type " + f"'{var.__class__.__name__}' expected '{var_cls.__name__}'", + ) + return default + return var + + +@overload +def get_request_library() -> Library: ... + + +@overload +def get_request_library(*, default: TDefault) -> Library | TDefault: ... + + +def get_request_library( + *, default: TDefault | Literal[SentinelType.NotGiven] = SentinelType.NotGiven +) -> Library | TDefault: + """ + Retrieve the 'library' attribute from the current Flask request object. + + This attribute should be set by using the @has_library or @allows_library decorator + on the route or by calling the BaseCirculationManagerController.library_for_request + method. + + Note: You need to specify a default of None if you want to allow the library to be + None (for example if you are using the @allows_library decorator). + + :param default: The default value to return if the 'library' attribute is not set. + If not provided, a `PalaceValueError` will be raised if the attribute is missing + or has an incorrect type. + + :return: The `Library` object from the request, or the default value if provided. + """ + return get_request_var("library", Library, default=default) -class PalaceFlask(Flask): +class PalaceFlask(flask.Flask): """ A subclass of Flask sets properties used by Palace. @@ -40,7 +140,7 @@ class PalaceFlask(Flask): Palace: You're going to need a stiff drink after this. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._db: Session self.manager: CirculationManager diff --git a/src/palace/manager/core/app_server.py b/src/palace/manager/core/app_server.py index 80bec63b5e..d5074afb5c 100644 --- a/src/palace/manager/core/app_server.py +++ b/src/palace/manager/core/app_server.py @@ -6,7 +6,7 @@ from collections.abc import Callable from functools import wraps from io import BytesIO -from typing import TYPE_CHECKING, ParamSpec, TypeVar +from typing import ParamSpec, TypeVar import flask from flask import Response, make_response, url_for @@ -15,6 +15,7 @@ from palace import manager from palace.manager.api.admin.config import Configuration as AdminUiConfig +from palace.manager.api.util.flask import PalaceFlask, get_request_library from palace.manager.core.problem_details import INVALID_URN from palace.manager.feed.acquisition import LookupAcquisitionFeed, OPDSAcquisitionFeed from palace.manager.sqlalchemy.model.identifier import Identifier @@ -23,9 +24,6 @@ from palace.manager.util.opds_writer import OPDSMessage from palace.manager.util.problem_detail import BaseProblemDetailException, ProblemDetail -if TYPE_CHECKING: - from palace.manager.api.util.flask import PalaceFlask - def load_facets_from_request( facet_config=None, @@ -51,7 +49,7 @@ def load_facets_from_request( kwargs = base_class_constructor_kwargs or dict() get_arg = flask.request.args.get get_header = flask.request.headers.get - library = flask.request.library + library = get_request_library() facet_config = facet_config or library return base_class.from_request( library, diff --git a/src/palace/manager/util/sentinel.py b/src/palace/manager/util/sentinel.py new file mode 100644 index 0000000000..c2f9f32201 --- /dev/null +++ b/src/palace/manager/util/sentinel.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from enum import Enum + + +class SentinelType(Enum): + """ + Sentinel value for when a variable is not given. + + We use this so we can differentiate between a variable that is not given + and a variable that is given as None. If https://peps.python.org/pep-0661/ + is accepted, we should update this is use a proper sentinel value. For now, + we use this enum, since we can type check it. + + This solution is based on discussion here: + https://github.com/python/typing/issues/236#issuecomment-227180301 + + It can be type hinted as: Literal[SentinelType.NotGiven] + """ + + NotGiven = "NotGiven" diff --git a/tests/fixtures/api_controller.py b/tests/fixtures/api_controller.py index 6466637fab..66da562ac0 100644 --- a/tests/fixtures/api_controller.py +++ b/tests/fixtures/api_controller.py @@ -6,7 +6,6 @@ from contextlib import contextmanager from typing import Any -import flask import pytest from sqlalchemy.orm import Session from werkzeug.datastructures import Authorization @@ -250,7 +249,7 @@ def request_context_with_library(self, route, *args, **kwargs): else: library = self.db.default_library() with self.app.test_request_context(route, *args, **kwargs) as c: - flask.request.library = library + setattr(c.request, "library", library) yield c diff --git a/tests/fixtures/flask.py b/tests/fixtures/flask.py index fe06cc3016..43d8cffa47 100644 --- a/tests/fixtures/flask.py +++ b/tests/fixtures/flask.py @@ -4,7 +4,6 @@ from contextlib import contextmanager from typing import Any -import flask import pytest from flask.ctx import RequestContext from flask_babel import Babel @@ -43,10 +42,10 @@ def test_request_context( ) -> Generator[RequestContext]: with self.app.test_request_context(*args, **kwargs) as c: self.db.session.begin_nested() - flask.request.library = library # type: ignore[attr-defined] - flask.request.admin = admin # type: ignore[attr-defined] - flask.request.form = ImmutableMultiDict() - flask.request.files = ImmutableMultiDict() + setattr(c.request, "library", library) + setattr(c.request, "admin", admin) + setattr(c.request, "form", ImmutableMultiDict()) + setattr(c.request, "files", ImmutableMultiDict()) try: yield c finally: diff --git a/tests/manager/api/admin/controller/test_custom_lists.py b/tests/manager/api/admin/controller/test_custom_lists.py index 743ac08474..e446b34cac 100644 --- a/tests/manager/api/admin/controller/test_custom_lists.py +++ b/tests/manager/api/admin/controller/test_custom_lists.py @@ -249,8 +249,8 @@ def test_custom_lists_post_errors( library = admin_librarian_fixture.ctrl.db.library() with admin_librarian_fixture.request_context_with_admin( "/", method="POST", admin=admin - ): - flask.request.library = library # type: ignore[attr-defined] + ) as ctx: + setattr(ctx.request, "library", library) form = ImmutableMultiDict( [ ("name", "name"), diff --git a/tests/manager/api/admin/controller/test_lanes.py b/tests/manager/api/admin/controller/test_lanes.py index 1cf72afd8a..140b3b8c3a 100644 --- a/tests/manager/api/admin/controller/test_lanes.py +++ b/tests/manager/api/admin/controller/test_lanes.py @@ -1,6 +1,5 @@ import json -import flask import pytest from werkzeug.datastructures import ImmutableMultiDict @@ -84,8 +83,7 @@ def test_lanes_get(self, alm_fixture: AdminLibraryManagerFixture): lane_for_list.priority = 2 lane_for_list.size = 1 - with alm_fixture.request_context_with_library_and_admin("/"): - flask.request.library = library # type: ignore[attr-defined] + with alm_fixture.request_context_with_library_and_admin("/", library=library): # The admin is not a librarian for this library. pytest.raises( AdminNotAuthorized, @@ -135,13 +133,17 @@ def test_lanes_get(self, alm_fixture: AdminLibraryManagerFixture): assert True == list_info.get("inherit_parent_restrictions") def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict([]) + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict([]) response = alm_fixture.manager.admin_lanes_controller.lanes() assert NO_DISPLAY_NAME_FOR_LANE == response - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("display_name", "lane"), ] @@ -154,8 +156,10 @@ def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): ) list.library = alm_fixture.ctrl.db.default_library() - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("id", "12345"), ("display_name", "lane"), @@ -166,9 +170,10 @@ def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): assert MISSING_LANE == response library = alm_fixture.ctrl.db.library() - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.library = library # type: ignore[attr-defined] - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST", library=library + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("display_name", "lane"), ("custom_list_ids", json.dumps([list.id])), @@ -183,8 +188,10 @@ def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): lane2 = alm_fixture.ctrl.db.lane("lane2") lane1.customlists += [list] - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("id", lane1.id), ("display_name", "lane2"), @@ -194,8 +201,10 @@ def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): response = alm_fixture.manager.admin_lanes_controller.lanes() assert LANE_WITH_PARENT_AND_DISPLAY_NAME_ALREADY_EXISTS == response - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("display_name", "lane2"), ("custom_list_ids", json.dumps([list.id])), @@ -204,8 +213,10 @@ def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): response = alm_fixture.manager.admin_lanes_controller.lanes() assert LANE_WITH_PARENT_AND_DISPLAY_NAME_ALREADY_EXISTS == response - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("parent_id", "12345"), ("display_name", "lane"), @@ -215,8 +226,10 @@ def test_lanes_post_errors(self, alm_fixture: AdminLibraryManagerFixture): response = alm_fixture.manager.admin_lanes_controller.lanes() assert MISSING_LANE.uri == response.uri - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("parent_id", lane1.id), ("display_name", "lane"), @@ -236,8 +249,10 @@ def test_lanes_create(self, alm_fixture: AdminLibraryManagerFixture): parent = alm_fixture.ctrl.db.lane("parent") sibling = alm_fixture.ctrl.db.lane("sibling", parent=parent) - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("parent_id", parent.id), ("display_name", "lane"), @@ -274,8 +289,8 @@ def test_lanes_create_shared_list(self, alm_fixture: AdminLibraryManagerFixture) with alm_fixture.request_context_with_library_and_admin( "/", method="POST", library=library - ): - flask.request.form = ImmutableMultiDict( + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("display_name", "lane"), ("custom_list_ids", json.dumps([list.id])), @@ -292,8 +307,8 @@ def test_lanes_create_shared_list(self, alm_fixture: AdminLibraryManagerFixture) with alm_fixture.request_context_with_library_and_admin( "/", method="POST", library=library - ): - flask.request.form = ImmutableMultiDict( + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("display_name", "lane"), ("custom_list_ids", json.dumps([list.id])), @@ -330,8 +345,10 @@ def test_lanes_edit(self, alm_fixture: AdminLibraryManagerFixture): # are two works in the lane. assert 0 == lane.size - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("id", str(lane.id)), ("display_name", "new name"), @@ -353,8 +370,10 @@ def test_default_lane_edit(self, alm_fixture: AdminLibraryManagerFixture): """Default lanes only allow the display_name to be edited""" lane: Lane = alm_fixture.ctrl.db.lane("default") customlist, _ = alm_fixture.ctrl.db.customlist() - with alm_fixture.request_context_with_library_and_admin("/", method="POST"): - flask.request.form = ImmutableMultiDict( + with alm_fixture.request_context_with_library_and_admin( + "/", method="POST" + ) as ctx: + ctx.request.form = ImmutableMultiDict( [ ("id", str(lane.id)), ("parent_id", "12345"), @@ -390,8 +409,9 @@ def test_lane_delete_success(self, alm_fixture: AdminLibraryManagerFixture): .count() ) - with alm_fixture.request_context_with_library_and_admin("/", method="DELETE"): - flask.request.library = library # type: ignore[attr-defined] + with alm_fixture.request_context_with_library_and_admin( + "/", method="DELETE", library=library + ): response = alm_fixture.manager.admin_lanes_controller.lane(lane.id) assert 200 == response.status_code @@ -426,8 +446,9 @@ def test_lane_delete_success(self, alm_fixture: AdminLibraryManagerFixture): .count() ) - with alm_fixture.request_context_with_library_and_admin("/", method="DELETE"): - flask.request.library = library # type: ignore[attr-defined] + with alm_fixture.request_context_with_library_and_admin( + "/", method="DELETE", library=library + ): response = alm_fixture.manager.admin_lanes_controller.lane(lane.id) assert 200 == response.status_code @@ -454,8 +475,9 @@ def test_lane_delete_errors(self, alm_fixture: AdminLibraryManagerFixture): lane = alm_fixture.ctrl.db.lane("lane") library = alm_fixture.ctrl.db.library() - with alm_fixture.request_context_with_library_and_admin("/", method="DELETE"): - flask.request.library = library # type: ignore[attr-defined] + with alm_fixture.request_context_with_library_and_admin( + "/", method="DELETE", library=library + ): pytest.raises( AdminNotAuthorized, alm_fixture.manager.admin_lanes_controller.lane, @@ -526,8 +548,7 @@ def test_reset(self, alm_fixture: AdminLibraryManagerFixture): library = alm_fixture.ctrl.db.library() old_lane = alm_fixture.ctrl.db.lane("old lane", library=library) - with alm_fixture.request_context_with_library_and_admin("/"): - flask.request.library = library # type: ignore[attr-defined] + with alm_fixture.request_context_with_library_and_admin("/", library=library): pytest.raises( AdminNotAuthorized, alm_fixture.manager.admin_lanes_controller.reset, @@ -570,9 +591,10 @@ def test_change_order(self, alm_fixture: AdminLibraryManagerFixture): {"id": parent1.id}, ] - with alm_fixture.request_context_with_library_and_admin("/"): - flask.request.library = library # type: ignore[attr-defined] - flask.request.data = json.dumps(new_order).encode() + with alm_fixture.request_context_with_library_and_admin( + "/", library=library + ) as ctx: + ctx.request.data = json.dumps(new_order).encode() pytest.raises( AdminNotAuthorized, diff --git a/tests/manager/api/admin/controller/test_work_editor.py b/tests/manager/api/admin/controller/test_work_editor.py index f62480c532..9b69800a35 100644 --- a/tests/manager/api/admin/controller/test_work_editor.py +++ b/tests/manager/api/admin/controller/test_work_editor.py @@ -767,8 +767,7 @@ def test_suppress(self, work_fixture: WorkFixture): ) # test no library - with work_fixture.request_context_with_library_and_admin("/"): - flask.request.library = None # type: ignore[attr-defined] + with work_fixture.request_context_with_library_and_admin("/", library=None): with pytest.raises(ProblemDetailException) as exc: work_fixture.manager.admin_work_controller.suppress( lp.identifier.type, lp.identifier.identifier @@ -817,8 +816,7 @@ def test_unsuppress(self, work_fixture: WorkFixture): ) # test no library - with work_fixture.request_context_with_library_and_admin("/"): - flask.request.library = None # type: ignore[attr-defined] + with work_fixture.request_context_with_library_and_admin("/", library=None): with pytest.raises(ProblemDetailException) as exc: work_fixture.manager.admin_work_controller.unsuppress( lp.identifier.type, lp.identifier.identifier diff --git a/tests/manager/api/controller/test_base.py b/tests/manager/api/controller/test_base.py index 0b64c41723..ec81bee5fd 100644 --- a/tests/manager/api/controller/test_base.py +++ b/tests/manager/api/controller/test_base.py @@ -497,19 +497,21 @@ def test_library_for_request( value = circulation_fixture.controller.library_for_request("not-a-library") assert LIBRARY_NOT_FOUND == value - with circulation_fixture.app.test_request_context("/"): + with circulation_fixture.app.test_request_context("/") as ctx: value = circulation_fixture.controller.library_for_request( circulation_fixture.db.default_library().short_name ) assert circulation_fixture.db.default_library() == value - assert circulation_fixture.db.default_library() == flask.request.library # type: ignore + assert circulation_fixture.db.default_library() == getattr( + ctx.request, "library" + ) # If you don't specify a library, the default library is used. - with circulation_fixture.app.test_request_context("/"): + with circulation_fixture.app.test_request_context("/") as ctx: value = circulation_fixture.controller.library_for_request(None) expect_default = Library.default(circulation_fixture.db.session) assert expect_default == value - assert expect_default == flask.request.library # type: ignore + assert expect_default == getattr(ctx.request, "library") def test_load_lane(self, circulation_fixture: CirculationControllerFixture): # Verify that requests for specific lanes are mapped to diff --git a/tests/manager/api/controller/test_index.py b/tests/manager/api/controller/test_index.py index b88227239f..4af1d638f6 100644 --- a/tests/manager/api/controller/test_index.py +++ b/tests/manager/api/controller/test_index.py @@ -1,7 +1,5 @@ import json -import flask - from palace.manager.sqlalchemy.model.lane import Lane from palace.manager.util.authentication_for_opds import AuthenticationForOPDSDocument from tests.fixtures.api_controller import CirculationControllerFixture @@ -9,8 +7,8 @@ class TestIndexController: def test_simple_redirect(self, circulation_fixture: CirculationControllerFixture): - with circulation_fixture.app.test_request_context("/"): - flask.request.library = circulation_fixture.library # type: ignore + with circulation_fixture.app.test_request_context("/") as ctx: + setattr(ctx.request, "library", circulation_fixture.library) response = circulation_fixture.manager.index_controller() assert 302 == response.status_code assert "http://localhost/default/groups/" == response.headers["location"] diff --git a/tests/manager/api/controller/test_scopedsession.py b/tests/manager/api/controller/test_scopedsession.py index b07e68493a..128b29a507 100644 --- a/tests/manager/api/controller/test_scopedsession.py +++ b/tests/manager/api/controller/test_scopedsession.py @@ -1,9 +1,9 @@ from collections.abc import Generator from contextlib import contextmanager -from unittest.mock import MagicMock +from unittest.mock import create_autospec -import flask import pytest +from flask.ctx import RequestContext from sqlalchemy.orm import Session from typing_extensions import Self @@ -11,6 +11,7 @@ from palace.manager.sqlalchemy.flask_sqlalchemy_session import current_session from palace.manager.sqlalchemy.model.datasource import DataSource from palace.manager.sqlalchemy.model.identifier import Identifier +from palace.manager.sqlalchemy.model.library import Library from tests.fixtures.database import DatabaseFixture from tests.fixtures.services import ServicesFixture from tests.mocks.circulation import MockCirculationManager @@ -27,7 +28,7 @@ def __init__( with db_fixture.patch_engine(): initialize_database() self.app.manager = MockCirculationManager(app._db, services.services) - self.mock_library = MagicMock() + self.mock_library = create_autospec(Library) self.mock_library.has_root_lanes = False def _cleanup(self) -> None: @@ -45,10 +46,10 @@ def fixture( fixture._cleanup() @contextmanager - def request_context(self, path: str) -> Generator[None, None, None]: + def request_context(self, path: str) -> Generator[RequestContext]: with self.app.test_request_context(path) as ctx: - ctx.request.library = self.mock_library # type: ignore[attr-defined] - yield + setattr(ctx.request, "library", self.mock_library) + yield ctx @pytest.fixture @@ -135,7 +136,6 @@ def test_scoped_session( # The controller still works in the new request context - # nothing it needs is associated with the previous scoped # session. - flask.request.library = scoped_session_fixture.mock_library # type: ignore[attr-defined] response = app.manager.index_controller() assert 302 == response.status_code diff --git a/tests/manager/api/test_authenticator.py b/tests/manager/api/test_authenticator.py index 87d90f3648..69b2ea3b2c 100644 --- a/tests/manager/api/test_authenticator.py +++ b/tests/manager/api/test_authenticator.py @@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Literal, cast from unittest.mock import MagicMock, PropertyMock, patch -import flask import pytest from _pytest._code import ExceptionInfo from flask import url_for @@ -588,8 +587,8 @@ def decode_bearer_token(self, *args, **kwargs): # This new library isn't in the authenticator. l3 = db.library(short_name="l3") - with app.test_request_context("/"): - flask.request.library = l3 # type:ignore + with app.test_request_context("/") as ctx: + setattr(ctx.request, "library", l3) assert LIBRARY_NOT_FOUND == auth.authenticated_patron(db.session, {}) assert LIBRARY_NOT_FOUND == auth.create_authentication_document() assert LIBRARY_NOT_FOUND == auth.create_authentication_headers() @@ -597,8 +596,8 @@ def decode_bearer_token(self, *args, **kwargs): assert LIBRARY_NOT_FOUND == auth.create_bearer_token() # The other libraries are in the authenticator. - with app.test_request_context("/"): - flask.request.library = l1 # type:ignore + with app.test_request_context("/") as ctx: + setattr(ctx.request, "library", l1) assert "authenticated patron for l1" == auth.authenticated_patron( db.session, {} ) @@ -613,8 +612,8 @@ def decode_bearer_token(self, *args, **kwargs): assert "bearer token for l1" == auth.create_bearer_token() assert "decoded bearer token for l1" == auth.decode_bearer_token() - with app.test_request_context("/"): - flask.request.library = l2 # type:ignore + with app.test_request_context("/") as ctx: + setattr(ctx.request, "library", l2) assert "authenticated patron for l2" == auth.authenticated_patron( db.session, {} ) diff --git a/tests/manager/api/test_circulationapi.py b/tests/manager/api/test_circulationapi.py index ca2b1aeed8..0568eccd98 100644 --- a/tests/manager/api/test_circulationapi.py +++ b/tests/manager/api/test_circulationapi.py @@ -960,11 +960,11 @@ def assert_event(inp, outp): # We must run the rest of the tests in a simulated Flask request # context. app = Flask(__name__) - with app.test_request_context(): + with app.test_request_context() as ctx: # The request library takes precedence over the Library # associated with the CirculationAPI (though this # shouldn't happen). - flask.request.library = l2 # type: ignore + setattr(ctx.request, "library", l2) assert_event( (None, None, "event"), ( @@ -976,11 +976,11 @@ def assert_event(inp, outp): ), ) - with app.test_request_context(): + with app.test_request_context() as ctx: # The library of the request patron also takes precedence # over both (though again, this shouldn't happen). - flask.request.library = l1 # type: ignore - flask.request.patron = p2 # type: ignore + setattr(ctx.request, "library", l1) + setattr(ctx.request, "patron", p2) assert_event( (None, None, "event"), ( diff --git a/tests/manager/api/util/test_flask.py b/tests/manager/api/util/test_flask.py new file mode 100644 index 0000000000..98cf041dd9 --- /dev/null +++ b/tests/manager/api/util/test_flask.py @@ -0,0 +1,35 @@ +import pytest + +from palace.manager.api.util.flask import get_request_var +from palace.manager.core.exceptions import PalaceValueError +from tests.fixtures.flask import FlaskAppFixture + + +class TestGetRequestVar: + def test_no_request_context(self) -> None: + # If we supply a default, we get the default if there is no request context. + assert get_request_var("foo", str, default="bar") == "bar" + + # If we don't supply a default, we get the normal RuntimeError. + with pytest.raises(RuntimeError, match="Working outside of request context"): + get_request_var("foo", str) + + def test_no_var_set(self, flask_app_fixture: FlaskAppFixture) -> None: + with flask_app_fixture.test_request_context(): + assert get_request_var("foo", str, default=None) is None + + with pytest.raises( + PalaceValueError, match="No 'foo' set on 'flask.request'" + ): + get_request_var("foo", str) + + def test_var_set_to_wrong_type(self, flask_app_fixture: FlaskAppFixture) -> None: + with flask_app_fixture.test_request_context() as ctx: + setattr(ctx.request, "foo", 123) + + assert get_request_var("foo", str, default=None) is None + + with pytest.raises( + PalaceValueError, match="incorrect type 'int' expected 'str'" + ): + get_request_var("foo", str) diff --git a/tests/manager/core/test_app_server.py b/tests/manager/core/test_app_server.py index 8497564497..fd58eefd97 100644 --- a/tests/manager/core/test_app_server.py +++ b/tests/manager/core/test_app_server.py @@ -348,46 +348,33 @@ def test_permalink(self, urn_lookup_controller_fixture: URNLookupControllerFixtu assert work.title in response_data -class LoadMethodsFixture: - transaction: DatabaseTransactionFixture - app: Flask - - -@pytest.fixture() -def load_methods_fixture( - db, -) -> LoadMethodsFixture: - data = LoadMethodsFixture() - data.transaction = db - data.app = Flask(LoadMethodsFixture.__name__) - Babel(data.app) - return data - - class TestLoadMethods: def test_load_facets_from_request( - self, load_methods_fixture: LoadMethodsFixture, library_fixture: LibraryFixture + self, + flask_app_fixture: FlaskAppFixture, + db: DatabaseTransactionFixture, + library_fixture: LibraryFixture, ): - fixture, data = load_methods_fixture, load_methods_fixture.transaction - # The library has two EntryPoints enabled. settings = library_fixture.mock_settings() settings.enabled_entry_points = [ EbooksEntryPoint.INTERNAL_NAME, AudiobooksEntryPoint.INTERNAL_NAME, ] - library = data.library(settings=settings) + library = db.library(settings=settings) - with fixture.app.test_request_context("/?order=%s" % Facets.ORDER_TITLE): - flask.request.library = library # type: ignore[attr-defined] + with flask_app_fixture.test_request_context( + "/?order=%s" % Facets.ORDER_TITLE, library=library + ): facets = load_facets_from_request() assert Facets.ORDER_TITLE == facets.order # Enabled facets are passed in to the newly created Facets, # in case the load method received a custom config. assert facets.facets_enabled_at_init is not None - with fixture.app.test_request_context("/?order=bad_facet"): - flask.request.library = library # type: ignore[attr-defined] + with flask_app_fixture.test_request_context( + "/?order=bad_facet", library=library + ): problemdetail = load_facets_from_request() assert INVALID_INPUT.uri == problemdetail.uri @@ -396,16 +383,18 @@ def test_load_facets_from_request( # configured on the present library. worklist = WorkList() worklist.initialize(library) - with fixture.app.test_request_context("/?entrypoint=Audio"): - flask.request.library = library # type: ignore[attr-defined] + with flask_app_fixture.test_request_context( + "/?entrypoint=Audio", library=library + ): facets = load_facets_from_request(worklist=worklist) assert AudiobooksEntryPoint == facets.entrypoint assert facets.entrypoint_is_default is False # If the requested EntryPoint not configured, the default # EntryPoint is used. - with fixture.app.test_request_context("/?entrypoint=NoSuchEntryPoint"): - flask.request.library = library # type: ignore[attr-defined] + with flask_app_fixture.test_request_context( + "/?entrypoint=NoSuchEntryPoint", library=library + ): default_entrypoint = object() facets = load_facets_from_request( worklist=worklist, default_entrypoint=default_entrypoint @@ -415,20 +404,19 @@ def test_load_facets_from_request( # Load a SearchFacets object that pulls information from an # HTTP header. - with fixture.app.test_request_context("/", headers={"Accept-Language": "ja"}): - flask.request.library = data.default_library() # type: ignore[attr-defined] + with flask_app_fixture.test_request_context( + "/", headers={"Accept-Language": "ja"}, library=library + ): facets = load_facets_from_request(base_class=SearchFacets) assert ["jpn"] == facets.languages def test_load_facets_from_request_class_instantiation( - self, load_methods_fixture: LoadMethodsFixture + self, flask_app_fixture: FlaskAppFixture, db: DatabaseTransactionFixture ): """The caller of load_facets_from_request() can specify a class other than Facets to call from_request() on. """ - fixture, data = load_methods_fixture, load_methods_fixture.transaction - class MockFacets: called_with: dict @@ -439,19 +427,14 @@ def from_request(*args, **kwargs): return facets kwargs = dict(some_arg="some value") - with fixture.app.test_request_context(""): - flask.request.library = data.default_library() # type: ignore[attr-defined] + with flask_app_fixture.test_request_context("", library=db.default_library()): facets = load_facets_from_request( None, None, base_class=MockFacets, base_class_constructor_kwargs=kwargs ) assert isinstance(facets, MockFacets) assert "some value" == facets.called_with["some_arg"] - def test_load_pagination_from_request( - self, load_methods_fixture: LoadMethodsFixture - ): - fixture = load_methods_fixture - + def test_load_pagination_from_request(self, flask_app_fixture: FlaskAppFixture): # Verify that load_pagination_from_request instantiates a # pagination object of the specified class (Pagination, by # default.) @@ -464,7 +447,7 @@ def from_request(cls, get_arg, default_size, **kwargs): cls.called_with = (get_arg, default_size, kwargs) return "I'm a pagination object!" - with fixture.app.test_request_context("/"): + with flask_app_fixture.test_request_context("/"): # Call load_pagination_from_request and verify that # Mock.from_request was called with the arguments we expect. extra_kwargs = dict(extra="kwarg") @@ -478,13 +461,13 @@ def from_request(cls, get_arg, default_size, **kwargs): # If no default size is specified, we trust from_request to # use the class default. - with fixture.app.test_request_context("/"): + with flask_app_fixture.test_request_context("/"): pagination = load_pagination_from_request(base_class=Mock) assert (flask.request.args.get, None, {}) == Mock.called_with # Now try a real case using the default pagination class, # Pagination - with fixture.app.test_request_context("/?size=50&after=10"): + with flask_app_fixture.test_request_context("/?size=50&after=10"): pagination = load_pagination_from_request() assert isinstance(pagination, Pagination) assert 50 == pagination.size