Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Helper function to get flask.request.patron #2178

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/palace/manager/api/controller/analytics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import flask
from flask import Response

from palace.manager.api.controller.circulation_manager import (
CirculationManagerController,
)
from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent
from palace.manager.util.datetime_helpers import utc_now
from palace.manager.util.problem_detail import ProblemDetail
Expand All @@ -21,8 +20,8 @@ def track_event(self, identifier_type, identifier, event_type):
if event_type in CirculationEvent.CLIENT_EVENTS:
library = get_request_library()
# Authentication on the AnalyticsController is optional,
# so flask.request.patron may or may not be set.
patron = getattr(flask.request, "patron", None)
# so we may not have a patron.
patron = get_request_patron(default=None)
neighborhood = None
if patron:
neighborhood = getattr(patron, "neighborhood", None)
Expand Down
5 changes: 3 additions & 2 deletions src/palace/manager/api/controller/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
CirculationManagerController,
)
from palace.manager.api.problem_details import NO_ANNOTATION
from palace.manager.api.util.flask import get_request_patron
from palace.manager.sqlalchemy.model.identifier import Identifier
from palace.manager.sqlalchemy.model.patron import Annotation
from palace.manager.sqlalchemy.util import get_one
Expand All @@ -30,7 +31,7 @@ def container(self, identifier=None, accept_post=True):
if flask.request.method == "HEAD":
return Response(status=200, headers=headers)

patron = flask.request.patron
patron = get_request_patron()

if flask.request.method == "GET":
headers["Link"] = [
Expand Down Expand Up @@ -78,7 +79,7 @@ def detail(self, annotation_id):
if flask.request.method == "HEAD":
return Response(status=200, headers=headers)

patron = flask.request.patron
patron = get_request_patron()

annotation = get_one(
self._db, Annotation, patron=patron, id=annotation_id, active=True
Expand Down
7 changes: 4 additions & 3 deletions src/palace/manager/api/controller/device_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEVICE_TOKEN_NOT_FOUND,
DEVICE_TOKEN_TYPE_INVALID,
)
from palace.manager.api.util.flask import get_request_patron
from palace.manager.sqlalchemy.model.devicetokens import (
DeviceToken,
DuplicateDeviceTokenError,
Expand All @@ -20,7 +21,7 @@

class DeviceTokensController(CirculationManagerController):
def get_patron_device(self):
patron = flask.request.patron
patron = get_request_patron()
device_token = flask.request.args["device_token"]
token: DeviceToken = (
self._db.query(DeviceToken)
Expand All @@ -35,7 +36,7 @@ def get_patron_device(self):
return dict(token_type=token.token_type, device_token=token.device_token), 200

def create_patron_device(self):
patron = flask.request.patron
patron = get_request_patron()
device_token = flask.request.json["device_token"]
token_type = flask.request.json["token_type"]

Expand All @@ -49,7 +50,7 @@ def create_patron_device(self):
return "", 201

def delete_patron_device(self):
patron = flask.request.patron
patron = get_request_patron()
device_token = flask.request.json["device_token"]
token_type = flask.request.json["token_type"]

Expand Down
10 changes: 5 additions & 5 deletions src/palace/manager/api/controller/loan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
NO_ACTIVE_LOAN_OR_HOLD,
NO_LICENSES,
)
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.celery.tasks.patron_activity import sync_patron_activity
from palace.manager.core.problem_details import INTERNAL_SERVER_ERROR
from palace.manager.feed.acquisition import OPDSAcquisitionFeed
Expand All @@ -46,7 +46,7 @@

:return: A Response containing an OPDS feed with up-to-date information.
"""
patron: Patron = flask.request.patron # type: ignore[attr-defined]
patron: Patron = get_request_patron()

try:
# Parse the refresh query parameter as a boolean.
Expand Down Expand Up @@ -89,7 +89,7 @@
"http://opds-spec.org/acquisition", which can be used to fetch the
book or the license file.
"""
patron = flask.request.patron # type: ignore[attr-defined]
patron = get_request_patron()
library = get_request_library()

header = self.authorization_header()
Expand Down Expand Up @@ -428,7 +428,7 @@
return self.circulation.can_fulfill_without_loan(patron, pool, lpdm)

def revoke(self, license_pool_id: int) -> OPDSEntryResponse | ProblemDetail:
patron = flask.request.patron # type: ignore[attr-defined]
patron = get_request_patron()
pool = self.load_licensepool(license_pool_id)
if isinstance(pool, ProblemDetail):
return pool
Expand Down Expand Up @@ -492,7 +492,7 @@
def detail(
self, identifier_type: str, identifier: str
) -> OPDSEntryResponse | ProblemDetail | None:
patron = flask.request.patron # type: ignore[attr-defined]
patron = get_request_patron()

Check warning on line 495 in src/palace/manager/api/controller/loan.py

View check run for this annotation

Codecov / codecov/patch

src/palace/manager/api/controller/loan.py#L495

Added line #L495 was not covered by tests
library = get_request_library()
pools = self.load_licensepools(library, identifier_type, identifier)
if isinstance(pools, ProblemDetail):
Expand Down
5 changes: 2 additions & 3 deletions src/palace/manager/api/controller/patron_activity_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import uuid

import flask
from flask import Response

from palace.manager.sqlalchemy.model.patron import Patron
from palace.manager.api.util.flask import get_request_patron


class PatronActivityHistoryController:

def reset_statistics_uuid(self) -> Response:
"""Resets the patron's the statistics UUID that links the patron to past activity thus effectively erasing the
link between activity history and patron."""
patron: Patron = flask.request.patron # type: ignore [attr-defined]
patron = get_request_patron()
patron.uuid = uuid.uuid4()
return Response("UUID reset", 200)
5 changes: 3 additions & 2 deletions src/palace/manager/api/controller/patron_auth_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
)
from palace.manager.api.model.patron_auth import PatronAuthAccessToken
from palace.manager.api.problem_details import PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE
from palace.manager.api.util.flask import get_request_patron
from palace.manager.util.log import LoggerMixin
from palace.manager.util.problem_detail import ProblemDetailException


class PatronAuthTokenController(CirculationManagerController, LoggerMixin):
def get_token(self):
"""Create a Patron Auth access token for an authenticated patron"""
patron = flask.request.patron
patron = get_request_patron(default=None)
auth = flask.request.authorization
token_expiry = 3600

if not patron or auth.type.lower() != "basic":
if patron is None or auth is None or auth.type.lower() != "basic":
return PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE

try:
Expand Down
4 changes: 2 additions & 2 deletions src/palace/manager/api/controller/playtime_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
PlaytimeEntriesPostResponse,
)
from palace.manager.api.problem_details import NOT_FOUND_ON_REMOTE
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.core.problem_details import INVALID_INPUT
from palace.manager.core.query.playtime_entries import PlaytimeEntries
from palace.manager.sqlalchemy.model.collection import Collection
Expand Down Expand Up @@ -77,7 +77,7 @@ def track_playtimes(self, collection_id, identifier_type, identifier_idn):
.join(LicensePool)
.where(
LicensePool.identifier == identifier,
Loan.patron == flask.request.patron,
Loan.patron == get_request_patron(),
Loan.start <= entry_max_start_time,
or_(Loan.end > entry_min_end_time, Loan.end == None),
)
Expand Down
3 changes: 2 additions & 1 deletion src/palace/manager/api/controller/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from palace.manager.api.controller.circulation_manager import (
CirculationManagerController,
)
from palace.manager.api.util.flask import get_request_patron
from palace.manager.core.user_profile import ProfileController as CoreProfileController
from palace.manager.util.problem_detail import ProblemDetail

Expand All @@ -21,7 +22,7 @@ def _controller(self, patron):

def protocol(self):
"""Handle a UPMP request."""
patron = flask.request.patron
patron = get_request_patron()
controller = self._controller(patron)
if flask.request.method == "GET":
result = controller.get()
Expand Down
4 changes: 2 additions & 2 deletions src/palace/manager/api/controller/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SeriesLane,
)
from palace.manager.api.problem_details import NO_SUCH_LANE, NOT_FOUND_ON_REMOTE
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.core.app_server import load_pagination_from_request
from palace.manager.core.config import CannotLoadConfiguration
from palace.manager.core.metadata_layer import ContributorData
Expand Down Expand Up @@ -111,7 +111,7 @@ def permalink(self, identifier_type, identifier):
if isinstance(work, ProblemDetail):
return work

patron = flask.request.patron
patron = get_request_patron(default=None)

if patron:
pools = self.load_licensepools(library, identifier_type, identifier)
Expand Down
28 changes: 28 additions & 0 deletions src/palace/manager/api/util/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from palace.manager.core.exceptions import PalaceValueError
from palace.manager.sqlalchemy.model.library import Library
from palace.manager.sqlalchemy.model.patron import Patron
from palace.manager.util.sentinel import SentinelType

if TYPE_CHECKING:
Expand Down Expand Up @@ -104,6 +105,33 @@ def get_request_library(
return get_request_var("library", Library, default=default)


@overload
def get_request_patron() -> Patron: ...


@overload
def get_request_patron(*, default: TDefault) -> Patron | TDefault: ...


def get_request_patron(
*, default: TDefault | Literal[SentinelType.NotGiven] = SentinelType.NotGiven
) -> Patron | TDefault:
"""
Retrieve the 'patron' attribute from the current Flask request object.

This attribute should be set by using the @requires_auth or @allows_auth decorator
on the route or by calling the BaseCirculationManagerController.authenticated_patron_from_request
method.

:param default: The default value to return if the 'patron' attribute is not set
or if there is no request context. If not provided, a `PalaceValueError` will be
raised if the attribute is missing or has an incorrect type.

:return: The `Patron` object from the request, or the default value if provided.
"""
return get_request_var("patron", Patron, default=default)


class PalaceFlask(flask.Flask):
"""
A subclass of Flask sets properties used by Palace.
Expand Down
3 changes: 1 addition & 2 deletions tests/fixtures/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def __init__(self, name):

def authenticated_patron_from_request(self):
if self.authenticated:
patron = object()
flask.request.patron = self.AUTHENTICATED_PATRON
setattr(flask.request, "patron", self.AUTHENTICATED_PATRON)
return self.AUTHENTICATED_PATRON
else:
return flask.Response(
Expand Down
3 changes: 3 additions & 0 deletions tests/fixtures/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from palace.manager.api.util.flask import PalaceFlask
from palace.manager.sqlalchemy.model.admin import Admin, AdminRole
from palace.manager.sqlalchemy.model.library import Library
from palace.manager.sqlalchemy.model.patron import Patron
from palace.manager.sqlalchemy.util import get_one_or_create
from tests.fixtures.database import DatabaseTransactionFixture

Expand All @@ -38,12 +39,14 @@ def test_request_context(
*args: Any,
admin: Admin | None = None,
library: Library | None = None,
patron: Patron | None = None,
**kwargs: Any,
) -> Generator[RequestContext]:
with self.app.test_request_context(*args, **kwargs) as c:
self.db.session.begin_nested()
setattr(c.request, "library", library)
setattr(c.request, "admin", admin)
setattr(c.request, "patron", patron)
setattr(c.request, "form", ImmutableMultiDict())
setattr(c.request, "files", ImmutableMultiDict())
try:
Expand Down
5 changes: 2 additions & 3 deletions tests/manager/api/controller/test_analytics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import flask
import pytest

from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE
Expand Down Expand Up @@ -42,8 +41,8 @@ def test_track_event(self, analytics_fixture: AnalyticsFixture):

patron = db.patron()
patron.neighborhood = "Mars Grid 4810579"
with analytics_fixture.request_context_with_library("/"):
flask.request.patron = patron # type: ignore
with analytics_fixture.request_context_with_library("/") as ctx:
setattr(ctx.request, "patron", patron)
response = analytics_fixture.manager.analytics_controller.track_event(
analytics_fixture.identifier.type,
analytics_fixture.identifier.identifier,
Expand Down
Loading