Skip to content

Commit

Permalink
Add type hints for opds for distributors
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathangreen committed Sep 22, 2023
1 parent dbd1a88 commit b98b810
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 57 deletions.
4 changes: 1 addition & 3 deletions api/circulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,9 +753,7 @@ def release_hold(self, patron: Patron, pin: str, licensepool: LicensePool) -> No
...

@abstractmethod
def update_availability(
self, licensepool: LicensePool
) -> Tuple[LicensePool, bool, bool]:
def update_availability(self, licensepool: LicensePool) -> None:
"""Update availability information for a book."""
...

Expand Down
4 changes: 1 addition & 3 deletions api/odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,7 @@ def update_loan(
_db.delete(loan)
self.update_licensepool(loan.license_pool)

def update_availability( # type: ignore[empty-body]
self, licensepool: LicensePool
) -> Tuple[LicensePool, bool, bool]:
def update_availability(self, licensepool: LicensePool) -> None:
pass


Expand Down
143 changes: 101 additions & 42 deletions api/opds_for_distributors.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,42 @@
from __future__ import annotations

import datetime
import json
from typing import Type
from typing import Any, Callable, Dict, Generator, List, Set, Tuple, Type

import feedparser
from flask_babel import lazy_gettext as _
from requests import Response

from api.selftest import HasCollectionSelfTests
from core.coverage import CoverageFailure
from core.integration.base import HasLibraryIntegrationConfiguration
from core.integration.settings import BaseSettings, ConfigurationFormItem, FormField
from core.metadata_layer import FormatData, TimestampData
from core.metadata_layer import CirculationData, FormatData, TimestampData
from core.model import (
Collection,
Credential,
DeliveryMechanism,
Edition,
ExternalIntegration,
Hyperlink,
Identifier,
LicensePool,
LicensePoolDeliveryMechanism,
Loan,
Patron,
RightsStatus,
Session,
Work,
get_one,
)
from core.opds_import import BaseOPDSImporterSettings, OPDSImporter, OPDSImportMonitor
from core.selftest import SelfTestResult
from core.util.datetime_helpers import utc_now
from core.util.http import HTTP
from core.util.string_helpers import base64

from .circulation import BaseCirculationAPI, FulfillmentInfo, LoanInfo
from .circulation import BaseCirculationAPI, FulfillmentInfo, HoldInfo, LoanInfo
from .circulation_exceptions import *


Expand Down Expand Up @@ -81,20 +90,22 @@ class OPDSForDistributorsAPI(
}

@classmethod
def settings_class(cls) -> Type[BaseSettings]:
def settings_class(cls) -> Type[OPDSForDistributorsSettings]:
return OPDSForDistributorsSettings

@classmethod
def library_settings_class(cls):
def library_settings_class(cls) -> Type[OPDSForDistributorsLibrarySettings]:
return OPDSForDistributorsLibrarySettings

def description(self):
return self.DESCRIPTION
@classmethod
def description(cls) -> str:
return cls.DESCRIPTION # type: ignore[no-any-return]

def label(self):
return self.NAME
@classmethod
def label(cls) -> str:
return cls.NAME

def __init__(self, _db, collection):
def __init__(self, _db: Session, collection: Collection):
super().__init__(_db, collection)
self.external_integration_id = collection.external_integration.id

Expand All @@ -103,24 +114,27 @@ def __init__(self, _db, collection):
self.username = config.username
self.password = config.password
self.feed_url = collection.external_account_id
self.auth_url = None

@property
def collection(self):
return Collection.by_id(self._db, id=self.collection_id)
self.auth_url: Optional[str] = None

def external_integration(self, _db):
def external_integration(self, _db: Session) -> Optional[ExternalIntegration]:
return get_one(_db, ExternalIntegration, id=self.external_integration_id)

def _run_self_tests(self, _db):
def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]:
"""Try to get a token."""
yield self.run_test("Negotiate a fulfillment token", self._get_token, _db)

def _request_with_timeout(self, method, url, *args, **kwargs):
def _request_with_timeout(
self, method: str, url: Optional[str], *args: Any, **kwargs: Any
) -> Response:
"""Wrapper around HTTP.request_with_timeout to be overridden for tests."""
if url is None:
name = self.collection.name if self.collection else "unknown"
raise LibraryAuthorizationFailedException(
f"No URL provided to request_with_timeout for collection: {name}/{self.collection_id}."
)
return HTTP.request_with_timeout(method, url, *args, **kwargs)

def _get_token(self, _db) -> Credential:
def _get_token(self, _db: Session) -> Credential:
# If this is the first time we're getting a token, we
# need to find the authenticate url in the OPDS
# authentication document.
Expand Down Expand Up @@ -212,7 +226,12 @@ def refresh(credential: Credential) -> None:
refresher_method=refresh,
)

def can_fulfill_without_loan(self, patron, licensepool, lpdm):
def can_fulfill_without_loan(
self,
patron: Optional[Patron],
pool: LicensePool,
lpdm: LicensePoolDeliveryMechanism,
) -> bool:
"""Since OPDS For Distributors delivers books to the library rather
than creating loans, any book can be fulfilled without
identifying the patron, assuming the library's policies
Expand All @@ -229,7 +248,7 @@ def can_fulfill_without_loan(self, patron, licensepool, lpdm):
return True
return False

def checkin(self, patron, pin, licensepool):
def checkin(self, patron: Patron, pin: str, licensepool: LicensePool) -> None:
# Delete the patron's loan for this licensepool.
_db = Session.object_session(patron)
try:
Expand All @@ -244,7 +263,13 @@ def checkin(self, patron, pin, licensepool):
# The patron didn't have this book checked out.
pass

def checkout(self, patron, pin, licensepool, internal_format):
def checkout(
self,
patron: Patron,
pin: str,
licensepool: LicensePool,
internal_format: Optional[str],
) -> LoanInfo:
now = utc_now()
return LoanInfo(
licensepool.collection,
Expand All @@ -255,7 +280,15 @@ def checkout(self, patron, pin, licensepool, internal_format):
end_date=None,
)

def fulfill(self, patron, pin, licensepool, internal_format, **kwargs):
def fulfill(
self,
patron: Patron,
pin: str,
licensepool: LicensePool,
internal_format: Optional[str] = None,
part: Optional[str] = None,
fulfill_part_url: Optional[Callable[[Optional[str]], str]] = None,
) -> FulfillmentInfo:
"""Retrieve a bearer token that can be used to download the book.
:param kwargs: A container for arguments to fulfill()
Expand All @@ -282,7 +315,7 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs):
# Build a application/vnd.librarysimplified.bearer-token
# document using information from the credential.
now = utc_now()
expiration = int((credential.expires - now).total_seconds())
expiration = int((credential.expires - now).total_seconds()) # type: ignore[operator]
token_document = dict(
token_type="Bearer",
access_token=credential.credential,
Expand All @@ -304,7 +337,7 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs):
# We couldn't find an acquisition link for this book.
raise CannotFulfill()

def patron_activity(self, patron, pin):
def patron_activity(self, patron: Patron, pin: str) -> List[LoanInfo | HoldInfo]:
# Look up loans for this collection in the database.
_db = Session.object_session(patron)
loans = (
Expand All @@ -325,43 +358,53 @@ def patron_activity(self, patron, pin):
for loan in loans
]

def release_hold(self, patron, pin, licensepool):
def release_hold(self, patron: Patron, pin: str, licensepool: LicensePool) -> None:
# All the books for this integration are available as simultaneous
# use, so there's no need to release a hold.
raise NotImplementedError()

def place_hold(self, patron, pin, licensepool, notification_email_address):
def place_hold(
self,
patron: Patron,
pin: str,
licensepool: LicensePool,
notification_email_address: Optional[str],
) -> HoldInfo:
# All the books for this integration are available as simultaneous
# use, so there's no need to place a hold.
raise NotImplementedError()

def update_availability(self, licensepool):
def update_availability(self, licensepool: LicensePool) -> None:
pass


class OPDSForDistributorsImporter(OPDSImporter):
NAME = OPDSForDistributorsAPI.NAME

@classmethod
def settings_class(cls):
def settings_class(cls) -> Type[OPDSForDistributorsSettings]: # type: ignore[override]
return OPDSForDistributorsSettings

def update_work_for_edition(self, *args, **kwargs):
def update_work_for_edition(
self,
edition: Edition,
is_open_access: bool = False,
) -> tuple[LicensePool | None, Work | None]:
"""After importing a LicensePool, set its availability appropriately.
Books imported through OPDS For Distributors can be designated as
either Open Access (handled elsewhere) or licensed (handled here). For
licensed content, a library that can perform this import is deemed to
have a license for the title and can distribute unlimited copies.
"""
pool, work = super().update_work_for_edition(
*args, is_open_access=False, **kwargs
)
pool.unlimited_access = True
pool, work = super().update_work_for_edition(edition, is_open_access=False)
if pool:
pool.unlimited_access = True

return pool, work

@classmethod
def _add_format_data(cls, circulation):
def _add_format_data(cls, circulation: CirculationData) -> None:
for link in circulation.links:
if (
link.rel == Hyperlink.GENERIC_OPDS_ACQUISITION
Expand All @@ -385,12 +428,20 @@ class OPDSForDistributorsImportMonitor(OPDSImportMonitor):
PROTOCOL = OPDSForDistributorsImporter.NAME
SERVICE_NAME = "OPDS for Distributors Import Monitor"

def __init__(self, _db, collection, import_class, **kwargs):
def __init__(
self,
_db: Session,
collection: Collection,
import_class: Type[OPDSImporter],
**kwargs: Any,
) -> None:
super().__init__(_db, collection, import_class, **kwargs)

self.api = OPDSForDistributorsAPI(_db, collection)

def _get(self, url, headers):
def _get(
self, url: str, headers: Dict[str, str]
) -> Tuple[int, Dict[str, str], bytes]:
"""Make a normal HTTP request for an OPDS feed, but add in an
auth header with the credentials for the collection.
"""
Expand All @@ -409,23 +460,31 @@ class OPDSForDistributorsReaperMonitor(OPDSForDistributorsImportMonitor):
has been removed from the collection.
"""

def __init__(self, _db, collection, import_class, **kwargs):
def __init__(
self,
_db: Session,
collection: Collection,
import_class: Type[OPDSImporter],
**kwargs: Any,
) -> None:
super().__init__(_db, collection, import_class, **kwargs)
self.seen_identifiers = set()
self.seen_identifiers: Set[str] = set()

def feed_contains_new_data(self, feed):
def feed_contains_new_data(self, feed: bytes | str) -> bool:
# Always return True so that the importer will crawl the
# entire feed.
return True

def import_one_feed(self, feed):
def import_one_feed(
self, feed: bytes | str
) -> Tuple[List[Edition], Dict[str, CoverageFailure | List[CoverageFailure]]]:
# Collect all the identifiers in the feed.
parsed_feed = feedparser.parse(feed)
identifiers = [entry.get("id") for entry in parsed_feed.get("entries", [])]
self.seen_identifiers.update(identifiers)
return [], {}

def run_once(self, progress):
def run_once(self, progress: TimestampData) -> TimestampData:
"""Check to see if any identifiers we know about are no longer
present on the remote. If there are any, remove them.
Expand Down
4 changes: 2 additions & 2 deletions api/selftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from abc import ABC
from typing import Iterable, Optional, Tuple, Union
from typing import Generator, Iterable, Optional, Tuple, Union

from sqlalchemy.orm.session import Session

Expand Down Expand Up @@ -157,7 +157,7 @@ def _no_delivery_mechanisms_test(self):
else:
return "All titles in this collection have delivery mechanisms."

def _run_self_tests(self):
def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]:
yield self.run_test(
"Checking for titles that have no delivery mechanisms.",
self._no_delivery_mechanisms_test,
Expand Down
4 changes: 2 additions & 2 deletions core/opds_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ def __init__(
def external_integration(self, _db: Session) -> Optional[ExternalIntegration]:
return get_one(_db, ExternalIntegration, id=self.external_integration_id)

def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]: # type: ignore[override]
def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]:
"""Retrieve the first page of the OPDS feed"""
first_page = self.run_test(
"Retrieve the first page of the OPDS feed (%s)" % self.feed_url,
Expand Down Expand Up @@ -1986,7 +1986,7 @@ def _get_feeds(self) -> Iterable[Tuple[str, bytes]]:
# pick up where we left off.
return reversed(feeds)

def run_once(self, progress_ignore: bool) -> TimestampData:
def run_once(self, progress: TimestampData) -> TimestampData:
feeds = self._get_feeds()
total_imported = 0
total_failures = 0
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ module = [
"api.lcp.hash",
"api.odl",
"api.odl2",
"api.opds_for_distributors",
"core.feed.*",
"core.integration.*",
"core.model.announcements",
Expand Down
Loading

0 comments on commit b98b810

Please sign in to comment.