From 6cee785135b07218153e342b556acab20334251d Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Mon, 2 Oct 2023 11:44:17 -0400 Subject: [PATCH] Fully type hint the OPDS Importer and OPDS2 Importer classes (PP-313) (#1397) * Add type hinting to OPDS importer class * Type hint ODL importer classes * Add type hints for opds for distributors --- api/axis.py | 2 +- api/circulation.py | 6 +- api/lcp/hash.py | 13 +- api/odl.py | 219 ++++++++----- api/odl2.py | 69 +++-- api/opds_for_distributors.py | 155 +++++++--- api/selftest.py | 4 +- core/feed/annotator/admin.py | 2 +- core/metadata_layer.py | 2 +- core/model/datasource.py | 2 +- core/model/identifier.py | 30 +- core/model/licensing.py | 23 +- core/model/measurement.py | 10 +- core/model/patron.py | 3 +- core/model/resource.py | 10 +- core/opds2_import.py | 63 ++-- core/opds_import.py | 378 +++++++++++++++-------- core/util/datetime_helpers.py | 12 +- core/util/xmlparser.py | 38 ++- pyproject.toml | 6 + tests/api/feed/test_library_annotator.py | 8 +- tests/api/test_odl.py | 6 +- tests/api/test_opds.py | 8 +- tests/api/test_opds_for_distributors.py | 17 +- tests/api/test_selftest.py | 3 +- tests/core/test_opds.py | 8 +- tests/core/test_opds_import.py | 39 +-- 27 files changed, 740 insertions(+), 396 deletions(-) diff --git a/api/axis.py b/api/axis.py index 546aa40076..1b0ad69a78 100644 --- a/api/axis.py +++ b/api/axis.py @@ -272,7 +272,7 @@ def _count_activity(): ) # Run the tests defined by HasCollectionSelfTests - for result in super()._run_self_tests(): + for result in super()._run_self_tests(_db): yield result def refresh_bearer_token(self): diff --git a/api/circulation.py b/api/circulation.py index 12fc47a4ff..290405d9d1 100644 --- a/api/circulation.py +++ b/api/circulation.py @@ -435,7 +435,7 @@ def __init__( identifier_type: Optional[str], identifier: Optional[str], start_date: Optional[datetime.datetime], - end_date: datetime.datetime, + end_date: Optional[datetime.datetime], fulfillment_info: Optional[FulfillmentInfo] = None, external_identifier: Optional[str] = None, locked_to: Optional[DeliveryMechanismInfo] = None, @@ -752,9 +752,7 @@ def release_hold(self, patron: Patron, pin: str, licensepool: LicensePool) -> No ... @abstractmethod - def update_availability( - self, licensepool: LicensePool - ) -> Tuple[LicensePool, bool, bool]: + def update_availability(self, licensepool: LicensePool) -> None: """Update availability information for a book.""" ... diff --git a/api/lcp/hash.py b/api/lcp/hash.py index 4569dfa58d..5fd98a75c7 100644 --- a/api/lcp/hash.py +++ b/api/lcp/hash.py @@ -1,5 +1,5 @@ import hashlib -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from enum import Enum from core.exceptions import BaseError @@ -14,20 +14,19 @@ class HashingError(BaseError): """Raised in the case of errors occurred during hashing""" -class Hasher(metaclass=ABCMeta): +class Hasher(ABC): """Base class for all implementations of different hashing algorithms""" - def __init__(self, hashing_algorithm): + def __init__(self, hashing_algorithm: HashingAlgorithm) -> None: """Initializes a new instance of Hasher class :param hashing_algorithm: Hashing algorithm - :type hashing_algorithm: HashingAlgorithm """ self._hashing_algorithm = hashing_algorithm @abstractmethod - def hash(self, value): - raise NotImplementedError() + def hash(self, value: str) -> str: + ... class UniversalHasher(Hasher): @@ -49,5 +48,5 @@ def hash(self, value: str) -> str: class HasherFactory: - def create(self, hashing_algorithm): + def create(self, hashing_algorithm: HashingAlgorithm) -> Hasher: return UniversalHasher(hashing_algorithm) diff --git a/api/odl.py b/api/odl.py index 02e00010a1..4ee86279b2 100644 --- a/api/odl.py +++ b/api/odl.py @@ -1,15 +1,18 @@ +from __future__ import annotations + import binascii import datetime import json import logging import uuid -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import dateutil -import sqlalchemy from flask import url_for from flask_babel import lazy_gettext as _ +from lxml.etree import Element from pydantic import HttpUrl, PositiveInt +from requests import Response from sqlalchemy.sql.expression import or_ from uritemplate import URITemplate @@ -218,20 +221,22 @@ class ODLAPI( ] @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[ODLSettings]: return ODLSettings @classmethod - def library_settings_class(cls): + def library_settings_class(cls) -> Type[ODLLibrarySettings]: return ODLLibrarySettings - def label(self): - return self.NAME + @classmethod + def label(cls) -> str: + return cls.NAME - def description(self): - return self.DESCRIPTION + @classmethod + def description(cls) -> str: + return cls.DESCRIPTION # type: ignore[no-any-return] - def __init__(self, _db, collection): + def __init__(self, _db: Session, collection: Collection) -> None: super().__init__(_db, collection) if collection.protocol != self.NAME: raise ValueError( @@ -252,9 +257,7 @@ def __init__(self, _db, collection): self._credential_factory = LCPCredentialFactory() self._hasher_instance: Optional[Hasher] = None - def external_integration( - self, db: sqlalchemy.orm.session.Session - ) -> ExternalIntegration: + def external_integration(self, db: Session) -> ExternalIntegration: """Return an external integration associated with this object. :param db: Database session @@ -262,7 +265,9 @@ def external_integration( """ return self.collection.external_integration - def internal_format(self, delivery_mechanism): + def internal_format( # type: ignore[override] + self, delivery_mechanism: Optional[LicensePoolDeliveryMechanism] + ) -> Optional[LicensePoolDeliveryMechanism]: """Each consolidated copy is only available in one format, so we don't need a mapping to internal formats. """ @@ -280,23 +285,22 @@ def collection(self) -> Collection: raise ValueError(f"Collection not found: {self.collection_id}") return collection - def _get_hasher(self): + def _get_hasher(self) -> Hasher: """Returns a Hasher instance :return: Hasher instance - :rtype: hash.Hasher """ config = self.configuration() if self._hasher_instance is None: self._hasher_instance = self._hasher_factory.create( - config.encryption_algorithm + config.encryption_algorithm # type: ignore[arg-type] if config.encryption_algorithm else ODLAPIConstants.DEFAULT_ENCRYPTION_ALGORITHM ) return self._hasher_instance - def _get(self, url, headers=None): + def _get(self, url: str, headers: Optional[Dict[str, str]] = None) -> Response: """Make a normal HTTP request, but include an authentication header with the credentials for the collection. """ @@ -309,11 +313,11 @@ def _get(self, url, headers=None): return HTTP.get_with_timeout(url, headers=headers) - def _url_for(self, *args, **kwargs): + def _url_for(self, *args: Any, **kwargs: Any) -> str: """Wrapper around flask's url_for to be overridden for tests.""" return url_for(*args, **kwargs) - def get_license_status_document(self, loan): + def get_license_status_document(self, loan: Loan) -> Dict[str, Any]: """Get the License Status Document for a loan. For a new loan, create a local loan with no external identifier and @@ -360,9 +364,10 @@ def get_license_status_document(self, loan): ) config = self.configuration() - url_template = URITemplate(loan.license.checkout_url) + checkout_url = str(loan.license.checkout_url) + url_template = URITemplate(checkout_url) url = url_template.expand( - id=id, + id=str(id), checkout_id=checkout_id, patron_id=patron_id, expires=expires.isoformat(), @@ -384,9 +389,9 @@ def get_license_status_document(self, loan): raise BadResponseException( url, "License Status Document had an unknown status value." ) - return status_doc + return status_doc # type: ignore[no-any-return] - def checkin(self, patron, pin, licensepool): + def checkin(self, patron: Patron, pin: str, licensepool: LicensePool) -> bool: # type: ignore[override] """Return a loan early.""" _db = Session.object_session(patron) @@ -397,10 +402,10 @@ def checkin(self, patron, pin, licensepool): ) if loan.count() < 1: raise NotCheckedOut() - loan = loan.one() - return self._checkin(loan) + loan_result = loan.one() + return self._checkin(loan_result) - def _checkin(self, loan): + def _checkin(self, loan: Loan) -> bool: _db = Session.object_session(loan) doc = self.get_license_status_document(loan) status = doc.get("status") @@ -427,7 +432,7 @@ def _checkin(self, loan): # must be returned through the DRM system. If that's true, the # app will already be doing that on its own, so we'll silently # do nothing. - return + return False # Hit the distributor's return link. self._get(return_url) @@ -439,12 +444,18 @@ def _checkin(self, loan): # However, it might be because the loan has already been fulfilled # and must be returned through the DRM system, which the app will # do on its own, so we can ignore the problem. - loan = get_one(_db, Loan, id=loan.id) - if loan: - return + new_loan = get_one(_db, Loan, id=loan.id) + if new_loan: + return False return True - def checkout(self, patron, pin, licensepool, internal_format): + def checkout( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + internal_format: Optional[str], + ) -> LoanInfo: """Create a new loan.""" _db = Session.object_session(patron) @@ -457,18 +468,20 @@ def checkout(self, patron, pin, licensepool, internal_format): raise AlreadyCheckedOut() hold = get_one(_db, Hold, patron=patron, license_pool_id=licensepool.id) - loan = self._checkout(patron, licensepool, hold) + loan_obj = self._checkout(patron, licensepool, hold) return LoanInfo( licensepool.collection, licensepool.data_source.name, licensepool.identifier.type, licensepool.identifier.identifier, - loan.start, - loan.end, - external_identifier=loan.external_identifier, + loan_obj.start, + loan_obj.end, + external_identifier=loan_obj.external_identifier, ) - def _checkout(self, patron: Patron, licensepool, hold=None): + def _checkout( + self, patron: Patron, licensepool: LicensePool, hold: Optional[Hold] = None + ) -> Loan: _db = Session.object_session(patron) if not any(l for l in licensepool.licenses if not l.is_inactive): @@ -483,7 +496,9 @@ def _checkout(self, patron: Patron, licensepool, hold=None): # If there's a holds queue, the patron must have a non-expired hold # with position 0 to check out the book. if ( - not hold or hold.position > 0 or (hold.end and hold.end < utc_now()) + not hold + or (hold.position and hold.position > 0) + or (hold.end and hold.end < utc_now()) ) and licensepool.licenses_available < 1: raise NoAvailableCopies() @@ -534,27 +549,28 @@ def _checkout(self, patron: Patron, licensepool, hold=None): self.update_licensepool(licensepool) return loan - def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): - """Get the actual resource file to the patron. - - :param kwargs: A container for arguments to fulfill() - which are not relevant to this vendor. - - :return: a FulfillmentInfo object. - """ + def fulfill( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + internal_format: Optional[str] = None, + part: Optional[str] = None, + fulfill_part_url: Optional[Callable[[Optional[str]], str]] = None, + ) -> FulfillmentInfo: + """Get the actual resource file to the patron.""" _db = Session.object_session(patron) loan = ( _db.query(Loan) .filter(Loan.patron == patron) .filter(Loan.license_pool_id == licensepool.id) - ) - loan = loan.one() + ).one() return self._fulfill(loan, internal_format) @staticmethod def _find_content_link_and_type( - links: List[Dict], + links: List[Dict[str, str]], drm_scheme: Optional[str], ) -> Tuple[Optional[str], Optional[str]]: """Find a content link with the type information corresponding to the selected delivery mechanism. @@ -647,7 +663,7 @@ def _count_holds_before(self, holdinfo: HoldInfo, pool: LicensePool) -> int: .count() ) - def _update_hold_data(self, hold: Hold): + def _update_hold_data(self, hold: Hold) -> None: pool: LicensePool = hold.license_pool holdinfo = HoldInfo( pool.collection, @@ -665,7 +681,7 @@ def _update_hold_data(self, hold: Hold): def _update_hold_end_date( self, holdinfo: HoldInfo, pool: LicensePool, library: Library - ): + ) -> None: _db = Session.object_session(pool) # First make sure the hold position is up-to-date, since we'll @@ -751,7 +767,7 @@ def _update_hold_end_date( days=default_reservation_period ) - def _update_hold_position(self, holdinfo: HoldInfo, pool: LicensePool): + def _update_hold_position(self, holdinfo: HoldInfo, pool: LicensePool) -> None: _db = Session.object_session(pool) loans_count = ( _db.query(Loan) @@ -774,7 +790,7 @@ def _update_hold_position(self, holdinfo: HoldInfo, pool: LicensePool): # Add 1 since position 0 indicates the hold is ready. holdinfo.hold_position = holds_count + 1 - def update_licensepool(self, licensepool: LicensePool): + def update_licensepool(self, licensepool: LicensePool) -> None: # Update the pool and the next holds in the queue when a license is reserved. licensepool.update_availability_from_licenses( analytics=self.analytics, @@ -786,11 +802,17 @@ def update_licensepool(self, licensepool: LicensePool): # This hold just got a reserved license. self._update_hold_data(hold) - def place_hold(self, patron, pin, licensepool, notification_email_address): + def place_hold( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + notification_email_address: Optional[str], + ) -> HoldInfo: """Create a new hold.""" return self._place_hold(patron, licensepool) - def _place_hold(self, patron, licensepool): + def _place_hold(self, patron: Patron, licensepool: LicensePool) -> HoldInfo: _db = Session.object_session(patron) # Make sure pool info is updated. @@ -810,14 +832,19 @@ def _place_hold(self, patron, licensepool): if hold is not None: raise AlreadyOnHold() - licensepool.patrons_in_hold_queue += 1 + patrons_in_hold_queue = ( + licensepool.patrons_in_hold_queue + if licensepool.patrons_in_hold_queue + else 0 + ) + licensepool.patrons_in_hold_queue = patrons_in_hold_queue + 1 holdinfo = HoldInfo( licensepool.collection, licensepool.data_source.name, licensepool.identifier.type, licensepool.identifier.identifier, utc_now(), - 0, + None, 0, ) library = patron.library @@ -825,7 +852,7 @@ def _place_hold(self, patron, licensepool): return holdinfo - def release_hold(self, patron, pin, licensepool): + def release_hold(self, patron: Patron, pin: str, licensepool: LicensePool) -> None: """Cancel a hold.""" _db = Session.object_session(patron) @@ -837,9 +864,9 @@ def release_hold(self, patron, pin, licensepool): ) if not hold: raise NotOnHold() - return self._release_hold(hold) + self._release_hold(hold) - def _release_hold(self, hold): + def _release_hold(self, hold: Hold) -> Literal[True]: # If the book was ready and the patron revoked the hold instead # of checking it out, but no one else had the book on hold, the # book is now available for anyone to check out. If someone else @@ -852,7 +879,7 @@ def _release_hold(self, hold): self.update_licensepool(licensepool) return True - def patron_activity(self, patron, pin): + def patron_activity(self, patron: Patron, pin: str) -> List[LoanInfo | HoldInfo]: """Look up non-expired loans for this collection in the database.""" _db = Session.object_session(patron) loans = ( @@ -904,7 +931,9 @@ def patron_activity(self, patron, pin): for hold in remaining_holds ] - def update_loan(self, loan, status_doc=None): + def update_loan( + self, loan: Loan, status_doc: Optional[Dict[str, Any]] = None + ) -> None: """Check a loan's status, and if it is no longer active, delete the loan and update its pool's availability. """ @@ -918,7 +947,8 @@ def update_loan(self, loan, status_doc=None): # but if the document came from a notification it hasn't been checked yet. if status not in self.STATUS_VALUES: raise BadResponseException( - "The License Status Document had an unknown status value." + str(loan.license.checkout_url), + "The License Status Document had an unknown status value.", ) if status in [ @@ -937,7 +967,7 @@ def update_loan(self, loan, status_doc=None): _db.delete(loan) self.update_licensepool(loan.license_pool) - def update_availability(self, licensepool): + def update_availability(self, licensepool: LicensePool) -> None: pass @@ -975,11 +1005,13 @@ class ODLImporter(OPDSImporter): } @classmethod - def fetch_license_info(cls, document_link: str, do_get: Callable) -> Optional[dict]: + def fetch_license_info( + cls, document_link: str, do_get: Callable[..., Tuple[int, Any, bytes]] + ) -> Optional[Dict[str, Any]]: status_code, _, response = do_get(document_link, headers={}) if status_code in (200, 201): license_info_document = json.loads(response) - return license_info_document + return license_info_document # type: ignore[no-any-return] else: logging.warning( f"License Info Document is not available. " @@ -990,9 +1022,9 @@ def fetch_license_info(cls, document_link: str, do_get: Callable) -> Optional[di @classmethod def parse_license_info( cls, - license_info_document: dict, + license_info_document: Dict[str, Any], license_info_link: str, - checkout_link: str, + checkout_link: Optional[str], ) -> Optional[LicenseData]: """Check the license's attributes passed as parameters: - if they're correct, turn them into a LicenseData object @@ -1078,11 +1110,11 @@ def parse_license_info( def get_license_data( cls, license_info_link: str, - checkout_link: str, - feed_license_identifier: str, - feed_license_expires: str, - feed_concurrency: int, - do_get: Callable, + checkout_link: Optional[str], + feed_license_identifier: Optional[str], + feed_license_expires: Optional[datetime.datetime], + feed_concurrency: Optional[int], + do_get: Callable[..., Tuple[int, Any, bytes]], ) -> Optional[LicenseData]: license_info_document = cls.fetch_license_info(license_info_link, do_get) @@ -1130,8 +1162,12 @@ def get_license_data( @classmethod def _detail_for_elementtree_entry( - cls, parser, entry_tag, feed_url=None, do_get=None - ): + cls, + parser: OPDSXMLParser, + entry_tag: Element, + feed_url: Optional[str] = None, + do_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + ) -> Dict[str, Any]: do_get = do_get or Representation.cautious_http_get # TODO: Review for consistency when updated ODL spec is ready. @@ -1152,7 +1188,7 @@ def _detail_for_elementtree_entry( # By default, dcterms:format includes the media type of a # DRM-free resource. content_type = full_content_type - drm_schemes = [] + drm_schemes: List[str | None] = [] # But it may instead describe an audiobook protected with # the Feedbooks access-control scheme. @@ -1206,11 +1242,12 @@ def _detail_for_elementtree_entry( concurrent_checkouts = subtag(terms[0], "odl:concurrent_checkouts") expires = subtag(terms[0], "odl:expires") - if concurrent_checkouts is not None: - concurrent_checkouts = int(concurrent_checkouts) - - if expires is not None: - expires = to_utc(dateutil.parser.parse(expires)) + concurrent_checkouts_int = ( + int(concurrent_checkouts) if concurrent_checkouts is not None else None + ) + expires_datetime = ( + to_utc(dateutil.parser.parse(expires)) if expires is not None else None + ) if not odl_status_link: parsed_license = None @@ -1219,8 +1256,8 @@ def _detail_for_elementtree_entry( odl_status_link, checkout_link, identifier, - expires, - concurrent_checkouts, + expires_datetime, + concurrent_checkouts_int, do_get, ) @@ -1248,7 +1285,13 @@ class ODLImportMonitor(OPDSImportMonitor): PROTOCOL = ODLImporter.NAME SERVICE_NAME = "ODL Import Monitor" - def __init__(self, _db, collection, import_class, **import_class_kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + import_class: Type[OPDSImporter], + **import_class_kwargs: Any, + ): # Always force reimport ODL collections to get up to date license information super().__init__( _db, collection, import_class, force_reimport=True, **import_class_kwargs @@ -1262,11 +1305,17 @@ class ODLHoldReaper(CollectionMonitor): SERVICE_NAME = "ODL Hold Reaper" PROTOCOL = ODLAPI.NAME - def __init__(self, _db, collection=None, api=None, **kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + api: Optional[ODLAPI] = None, + **kwargs: Any, + ): super().__init__(_db, collection, **kwargs) self.api = api or ODLAPI(_db, collection) - def run_once(self, progress): + def run_once(self, progress: TimestampData) -> TimestampData: # Find holds that have expired. expired_holds = ( self._db.query(Hold) diff --git a/api/odl2.py b/api/odl2.py index b4b8bc0dbd..946d71eb63 100644 --- a/api/odl2.py +++ b/api/odl2.py @@ -1,10 +1,11 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type from flask_babel import lazy_gettext as _ from pydantic import PositiveInt +from sqlalchemy.orm import Session from webpub_manifest_parser.odl import ODLFeedParserFactory from webpub_manifest_parser.opds2.registry import OPDS2LinkRelationsRegistry @@ -23,7 +24,12 @@ from core.util.datetime_helpers import to_utc if TYPE_CHECKING: - from core.model.patron import Patron + from webpub_manifest_parser.core.ast import Metadata + from webpub_manifest_parser.opds2.ast import OPDS2Feed, OPDS2Publication + + from api.circulation import HoldInfo + from core.model import Collection, Identifier, LicensePool + from core.model.patron import Hold, Loan, Patron class ODL2Settings(ODLSettings): @@ -71,16 +77,18 @@ class ODL2API(ODLAPI): NAME = ExternalIntegration.ODL2 @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[ODL2Settings]: return ODL2Settings - def __init__(self, _db, collection): + def __init__(self, _db: Session, collection: Collection) -> None: super().__init__(_db, collection) config = self.configuration() - self.loan_limit = config.loan_limit - self.hold_limit = config.hold_limit + self.loan_limit = config.loan_limit # type: ignore[attr-defined] + self.hold_limit = config.hold_limit # type: ignore[attr-defined] - def _checkout(self, patron: Patron, licensepool, hold=None): + def _checkout( + self, patron: Patron, licensepool: LicensePool, hold: Optional[Hold] = None + ) -> Loan: # If the loan limit is not None or 0 if self.loan_limit: loans = list( @@ -93,7 +101,7 @@ def _checkout(self, patron: Patron, licensepool, hold=None): raise PatronLoanLimitReached(limit=self.loan_limit) return super()._checkout(patron, licensepool, hold) - def _place_hold(self, patron: Patron, licensepool): + def _place_hold(self, patron: Patron, licensepool: LicensePool) -> HoldInfo: # If the hold limit is not None or 0 if self.hold_limit: holds = list( @@ -117,19 +125,19 @@ class ODL2Importer(OPDS2Importer, HasExternalIntegration): NAME = ODL2API.NAME @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[ODL2Settings]: # type: ignore[override] return ODL2Settings def __init__( self, - db, - collection, - parser=None, - data_source_name=None, - identifier_mapping=None, - http_get=None, - content_modifier=None, - map_from_collection=None, + db: Session, + collection: Collection, + parser: Optional[RWPMManifestParser] = None, + data_source_name: str | None = None, + identifier_mapping: Dict[Identifier, Identifier] | None = None, + http_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + content_modifier: Optional[Callable[..., None]] = None, + map_from_collection: Optional[bool] = None, ): """Initialize a new instance of ODL2Importer class. @@ -173,20 +181,19 @@ def __init__( ) self._logger = logging.getLogger(__name__) - def _extract_publication_metadata(self, feed, publication, data_source_name): + def _extract_publication_metadata( + self, + feed: OPDS2Feed, + publication: OPDS2Publication, + data_source_name: Optional[str], + ) -> Metadata: """Extract a Metadata object from webpub-manifest-parser's publication. :param publication: Feed object - :type publication: opds2_ast.OPDS2Feed - :param publication: Publication object - :type publication: opds2_ast.OPDS2Publication - :param data_source_name: Data source's name - :type data_source_name: str :return: Publication's metadata - :rtype: Metadata """ metadata = super()._extract_publication_metadata( feed, publication, data_source_name @@ -195,7 +202,7 @@ def _extract_publication_metadata(self, feed, publication, data_source_name): licenses = [] medium = None - skipped_license_formats = self.configuration().skipped_license_formats + skipped_license_formats = self.configuration().skipped_license_formats # type: ignore[attr-defined] if skipped_license_formats: skipped_license_formats = set(skipped_license_formats) @@ -251,6 +258,7 @@ def _extract_publication_metadata(self, feed, publication, data_source_name): if not medium: medium = Edition.medium_from_media_type(license_format) + drm_schemes: List[str | None] if license_format in ODLImporter.LICENSE_FORMATS: # Special case to handle DeMarque audiobooks which include the protection # in the content type. When we see a license format of @@ -291,9 +299,6 @@ def _extract_publication_metadata(self, feed, publication, data_source_name): return metadata - def external_integration(self, db): - return self.collection.external_integration - class ODL2ImportMonitor(OPDS2ImportMonitor): """Import information from an ODL feed.""" @@ -301,7 +306,13 @@ class ODL2ImportMonitor(OPDS2ImportMonitor): PROTOCOL = ODL2Importer.NAME SERVICE_NAME = "ODL 2.x Import Monitor" - def __init__(self, _db, collection, import_class, **import_class_kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + import_class: Type[ODL2Importer], + **import_class_kwargs: Any, + ) -> None: # Always force reimport ODL collections to get up to date license information super().__init__( _db, collection, import_class, force_reimport=True, **import_class_kwargs diff --git a/api/opds_for_distributors.py b/api/opds_for_distributors.py index 88d5f753f8..be535dada3 100644 --- a/api/opds_for_distributors.py +++ b/api/opds_for_distributors.py @@ -1,10 +1,28 @@ +from __future__ import annotations + import datetime import json -from typing import Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Tuple, + Type, +) import feedparser from flask_babel import lazy_gettext as _ +from api.circulation import BaseCirculationAPI, FulfillmentInfo, LoanInfo +from api.circulation_exceptions import ( + CannotFulfill, + LibraryAuthorizationFailedException, +) from api.selftest import HasCollectionSelfTests from core.integration.base import HasLibraryIntegrationConfiguration from core.integration.settings import BaseSettings, ConfigurationFormItem, FormField @@ -27,8 +45,14 @@ from core.util.http import HTTP from core.util.string_helpers import base64 -from .circulation import BaseCirculationAPI, FulfillmentInfo, LoanInfo -from .circulation_exceptions import * +if TYPE_CHECKING: + from requests import Response + + from api.circulation import HoldInfo + from core.coverage import CoverageFailure + from core.metadata_layer import CirculationData + from core.model import Edition, LicensePoolDeliveryMechanism, Patron, Work + from core.selftest import SelfTestResult class OPDSForDistributorsSettings(BaseOPDSImporterSettings): @@ -81,20 +105,22 @@ class OPDSForDistributorsAPI( } @classmethod - def settings_class(cls) -> Type[BaseSettings]: + def settings_class(cls) -> Type[OPDSForDistributorsSettings]: return OPDSForDistributorsSettings @classmethod - def library_settings_class(cls): + def library_settings_class(cls) -> Type[OPDSForDistributorsLibrarySettings]: return OPDSForDistributorsLibrarySettings - def description(self): - return self.DESCRIPTION + @classmethod + def description(cls) -> str: + return cls.DESCRIPTION # type: ignore[no-any-return] - def label(self): - return self.NAME + @classmethod + def label(cls) -> str: + return cls.NAME - def __init__(self, _db, collection): + def __init__(self, _db: Session, collection: Collection): super().__init__(_db, collection) self.external_integration_id = collection.external_integration.id @@ -103,24 +129,27 @@ def __init__(self, _db, collection): self.username = config.username self.password = config.password self.feed_url = collection.external_account_id - self.auth_url = None - - @property - def collection(self): - return Collection.by_id(self._db, id=self.collection_id) + self.auth_url: Optional[str] = None - def external_integration(self, _db): + def external_integration(self, _db: Session) -> Optional[ExternalIntegration]: return get_one(_db, ExternalIntegration, id=self.external_integration_id) - def _run_self_tests(self, _db): + def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]: """Try to get a token.""" yield self.run_test("Negotiate a fulfillment token", self._get_token, _db) - def _request_with_timeout(self, method, url, *args, **kwargs): + def _request_with_timeout( + self, method: str, url: Optional[str], *args: Any, **kwargs: Any + ) -> Response: """Wrapper around HTTP.request_with_timeout to be overridden for tests.""" + if url is None: + name = self.collection.name if self.collection else "unknown" + raise LibraryAuthorizationFailedException( + f"No URL provided to request_with_timeout for collection: {name}/{self.collection_id}." + ) return HTTP.request_with_timeout(method, url, *args, **kwargs) - def _get_token(self, _db) -> Credential: + def _get_token(self, _db: Session) -> Credential: # If this is the first time we're getting a token, we # need to find the authenticate url in the OPDS # authentication document. @@ -212,7 +241,12 @@ def refresh(credential: Credential) -> None: refresher_method=refresh, ) - def can_fulfill_without_loan(self, patron, licensepool, lpdm): + def can_fulfill_without_loan( + self, + patron: Optional[Patron], + pool: LicensePool, + lpdm: LicensePoolDeliveryMechanism, + ) -> bool: """Since OPDS For Distributors delivers books to the library rather than creating loans, any book can be fulfilled without identifying the patron, assuming the library's policies @@ -229,7 +263,7 @@ def can_fulfill_without_loan(self, patron, licensepool, lpdm): return True return False - def checkin(self, patron, pin, licensepool): + def checkin(self, patron: Patron, pin: str, licensepool: LicensePool) -> None: # Delete the patron's loan for this licensepool. _db = Session.object_session(patron) try: @@ -244,7 +278,13 @@ def checkin(self, patron, pin, licensepool): # The patron didn't have this book checked out. pass - def checkout(self, patron, pin, licensepool, internal_format): + def checkout( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + internal_format: Optional[str], + ) -> LoanInfo: now = utc_now() return LoanInfo( licensepool.collection, @@ -255,7 +295,15 @@ def checkout(self, patron, pin, licensepool, internal_format): end_date=None, ) - def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): + def fulfill( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + internal_format: Optional[str] = None, + part: Optional[str] = None, + fulfill_part_url: Optional[Callable[[Optional[str]], str]] = None, + ) -> FulfillmentInfo: """Retrieve a bearer token that can be used to download the book. :param kwargs: A container for arguments to fulfill() @@ -282,7 +330,7 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): # Build a application/vnd.librarysimplified.bearer-token # document using information from the credential. now = utc_now() - expiration = int((credential.expires - now).total_seconds()) + expiration = int((credential.expires - now).total_seconds()) # type: ignore[operator] token_document = dict( token_type="Bearer", access_token=credential.credential, @@ -304,7 +352,7 @@ def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): # We couldn't find an acquisition link for this book. raise CannotFulfill() - def patron_activity(self, patron, pin): + def patron_activity(self, patron: Patron, pin: str) -> List[LoanInfo | HoldInfo]: # Look up loans for this collection in the database. _db = Session.object_session(patron) loans = ( @@ -325,17 +373,23 @@ def patron_activity(self, patron, pin): for loan in loans ] - def release_hold(self, patron, pin, licensepool): + def release_hold(self, patron: Patron, pin: str, licensepool: LicensePool) -> None: # All the books for this integration are available as simultaneous # use, so there's no need to release a hold. raise NotImplementedError() - def place_hold(self, patron, pin, licensepool, notification_email_address): + def place_hold( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + notification_email_address: Optional[str], + ) -> HoldInfo: # All the books for this integration are available as simultaneous # use, so there's no need to place a hold. raise NotImplementedError() - def update_availability(self, licensepool): + def update_availability(self, licensepool: LicensePool) -> None: pass @@ -343,10 +397,14 @@ class OPDSForDistributorsImporter(OPDSImporter): NAME = OPDSForDistributorsAPI.NAME @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[OPDSForDistributorsSettings]: # type: ignore[override] return OPDSForDistributorsSettings - def update_work_for_edition(self, *args, **kwargs): + def update_work_for_edition( + self, + edition: Edition, + is_open_access: bool = False, + ) -> tuple[LicensePool | None, Work | None]: """After importing a LicensePool, set its availability appropriately. Books imported through OPDS For Distributors can be designated as @@ -354,17 +412,14 @@ def update_work_for_edition(self, *args, **kwargs): licensed content, a library that can perform this import is deemed to have a license for the title and can distribute unlimited copies. """ - pool, work = super().update_work_for_edition( - *args, is_open_access=False, **kwargs - ) - + pool, work = super().update_work_for_edition(edition, is_open_access=False) if pool: pool.unlimited_access = True return pool, work @classmethod - def _add_format_data(cls, circulation): + def _add_format_data(cls, circulation: CirculationData) -> None: for link in circulation.links: if ( link.rel == Hyperlink.GENERIC_OPDS_ACQUISITION @@ -388,12 +443,20 @@ class OPDSForDistributorsImportMonitor(OPDSImportMonitor): PROTOCOL = OPDSForDistributorsImporter.NAME SERVICE_NAME = "OPDS for Distributors Import Monitor" - def __init__(self, _db, collection, import_class, **kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + import_class: Type[OPDSImporter], + **kwargs: Any, + ) -> None: super().__init__(_db, collection, import_class, **kwargs) self.api = OPDSForDistributorsAPI(_db, collection) - def _get(self, url, headers): + def _get( + self, url: str, headers: Dict[str, str] + ) -> Tuple[int, Dict[str, str], bytes]: """Make a normal HTTP request for an OPDS feed, but add in an auth header with the credentials for the collection. """ @@ -412,23 +475,31 @@ class OPDSForDistributorsReaperMonitor(OPDSForDistributorsImportMonitor): has been removed from the collection. """ - def __init__(self, _db, collection, import_class, **kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + import_class: Type[OPDSImporter], + **kwargs: Any, + ) -> None: super().__init__(_db, collection, import_class, **kwargs) - self.seen_identifiers = set() + self.seen_identifiers: Set[str] = set() - def feed_contains_new_data(self, feed): + def feed_contains_new_data(self, feed: bytes | str) -> bool: # Always return True so that the importer will crawl the # entire feed. return True - def import_one_feed(self, feed): + def import_one_feed( + self, feed: bytes | str + ) -> Tuple[List[Edition], Dict[str, CoverageFailure | List[CoverageFailure]]]: # Collect all the identifiers in the feed. parsed_feed = feedparser.parse(feed) identifiers = [entry.get("id") for entry in parsed_feed.get("entries", [])] self.seen_identifiers.update(identifiers) return [], {} - def run_once(self, progress): + def run_once(self, progress: TimestampData) -> TimestampData: """Check to see if any identifiers we know about are no longer present on the remote. If there are any, remove them. diff --git a/api/selftest.py b/api/selftest.py index 3f26f58ff9..d86de07fe3 100644 --- a/api/selftest.py +++ b/api/selftest.py @@ -2,7 +2,7 @@ import logging from abc import ABC -from typing import Iterable, Optional, Tuple, Union +from typing import Generator, Iterable, Optional, Tuple, Union from sqlalchemy.orm.session import Session @@ -157,7 +157,7 @@ def _no_delivery_mechanisms_test(self): else: return "All titles in this collection have delivery mechanisms." - def _run_self_tests(self): + def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]: yield self.run_test( "Checking for titles that have no delivery mechanisms.", self._no_delivery_mechanisms_test, diff --git a/core/feed/annotator/admin.py b/core/feed/annotator/admin.py index 8b5e903e29..27da250676 100644 --- a/core/feed/annotator/admin.py +++ b/core/feed/annotator/admin.py @@ -28,7 +28,7 @@ def annotate_work_entry( # Find staff rating and add a tag for it. for measurement in identifier.measurements: if ( - measurement.data_source.name == DataSource.LIBRARY_STAFF # type: ignore[attr-defined] + measurement.data_source.name == DataSource.LIBRARY_STAFF and measurement.is_most_recent and measurement.value is not None ): diff --git a/core/metadata_layer.py b/core/metadata_layer.py index e73ae1dc8b..4fe5e06edf 100644 --- a/core/metadata_layer.py +++ b/core/metadata_layer.py @@ -523,7 +523,7 @@ class LicenseData(LicenseFunctions): def __init__( self, identifier: str, - checkout_url: str, + checkout_url: Optional[str], status_url: str, status: LicenseStatus, checkouts_available: int, diff --git a/core/model/datasource.py b/core/model/datasource.py index d4bf6ecc5b..cc75bd275d 100644 --- a/core/model/datasource.py +++ b/core/model/datasource.py @@ -72,7 +72,7 @@ class DataSource(Base, HasSessionCache, DataSourceConstants): # One DataSource can generate many Measurements. measurements: Mapped[List[Measurement]] = relationship( - "Measurement", backref="data_source" + "Measurement", back_populates="data_source" ) # One DataSource can provide many Classifications. diff --git a/core/model/identifier.py b/core/model/identifier.py index 509d2d5a1a..295053181b 100644 --- a/core/model/identifier.py +++ b/core/model/identifier.py @@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict from functools import total_ordering -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, overload from urllib.parse import quote, unquote import isbnlib @@ -391,16 +391,16 @@ def valid_as_foreign_identifier(cls, type, id): return True @property - def urn(self): - identifier_text = quote(self.identifier) + def urn(self) -> str: + identifier_text = quote(self.identifier or "") if self.type == Identifier.ISBN: return self.ISBN_URN_SCHEME_PREFIX + identifier_text elif self.type == Identifier.URI: - return self.identifier + return self.identifier or "" elif self.type == Identifier.GUTENBERG_ID: return self.GUTENBERG_URN_SCHEME_PREFIX + identifier_text else: - identifier_type = quote(self.type) + identifier_type = quote(self.type or "") return self.URN_SCHEME_PREFIX + "{}/{}".format( identifier_type, identifier_text ) @@ -561,6 +561,26 @@ def _parse_urn( return cls.for_foreign_id(_db, identifier_type, identifier_string) + @classmethod + @overload + def parse_urn( + cls, + _db: Session, + identifier_string: str, + must_support_license_pools: bool = False, + ) -> tuple[Identifier, bool]: + ... + + @classmethod + @overload + def parse_urn( + cls, + _db: Session, + identifier_string: str | None, + must_support_license_pools: bool = False, + ) -> tuple[Identifier | None, bool | None]: + ... + @classmethod def parse_urn( cls, diff --git a/core/model/licensing.py b/core/model/licensing.py index 93ab351e59..3511d3b8bd 100644 --- a/core/model/licensing.py +++ b/core/model/licensing.py @@ -5,7 +5,7 @@ import datetime import logging from enum import Enum as PythonEnum -from typing import TYPE_CHECKING, List, Literal, Tuple, overload +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, overload from sqlalchemy import Boolean, Column, DateTime from sqlalchemy import Enum as AlchemyEnum @@ -139,12 +139,12 @@ class License(Base, LicenseFunctions): # One License can have many Loans. loans: Mapped[List[Loan]] = relationship( - "Loan", backref="license", cascade="all, delete-orphan" + "Loan", back_populates="license", cascade="all, delete-orphan" ) __table_args__ = (UniqueConstraint("identifier", "license_pool_id"),) - def loan_to(self, patron: Patron, **kwargs): + def loan_to(self, patron: Patron, **kwargs) -> Tuple[Loan, bool]: loan, is_new = self.license_pool.loan_to(patron, **kwargs) loan.license = self return loan, is_new @@ -1021,7 +1021,7 @@ def loan_to( end=None, fulfillment=None, external_identifier=None, - ): + ) -> Tuple[Loan, bool]: _db = Session.object_session(patron) kwargs = dict(start=start or utc_now(), end=end) loan, is_new = get_one_or_create( @@ -1067,7 +1067,7 @@ def on_hold_to( hold.external_identifier = external_identifier return hold, new - def best_available_license(self): + def best_available_license(self) -> License | None: """Determine the next license that should be lent out for this pool. Time-limited licenses and perpetual licenses are the best. It doesn't matter which @@ -1084,7 +1084,7 @@ def best_available_license(self): The worst option would be pay-per-use, but we don't yet support any distributors that offer that model. """ - best = None + best: Optional[License] = None now = utc_now() for license in self.licenses: @@ -1094,7 +1094,10 @@ def best_available_license(self): active_loan_count = len( [l for l in license.loans if not l.end or l.end > now] ) - if active_loan_count >= license.checkouts_available: + checkouts_available = ( + license.checkouts_available if license.checkouts_available else 0 + ) + if active_loan_count >= checkouts_available: continue if ( @@ -1103,13 +1106,13 @@ def best_available_license(self): or ( license.is_time_limited and best.is_time_limited - and license.expires < best.expires + and license.expires < best.expires # type: ignore[operator] ) or (license.is_perpetual and not best.is_time_limited) or ( license.is_loan_limited and best.is_loan_limited - and license.checkouts_left > best.checkouts_left + and license.checkouts_left > best.checkouts_left # type: ignore[operator] ) ): best = license @@ -2024,7 +2027,7 @@ def lookup(cls, _db, uri): return status @classmethod - def rights_uri_from_string(cls, rights): + def rights_uri_from_string(cls, rights: str) -> str: rights = rights.lower() if rights == "public domain in the usa.": return RightsStatus.PUBLIC_DOMAIN_USA diff --git a/core/model/measurement.py b/core/model/measurement.py index 0f0e74aead..751fe1a52b 100644 --- a/core/model/measurement.py +++ b/core/model/measurement.py @@ -1,14 +1,19 @@ # Measurement - +from __future__ import annotations import bisect import logging +from typing import TYPE_CHECKING from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, Unicode +from sqlalchemy.orm import Mapped, relationship from . import Base from .constants import DataSourceConstants +if TYPE_CHECKING: + from .datasource import DataSource + class Measurement(Base): """A measurement of some numeric quantity associated with a @@ -711,6 +716,9 @@ class Measurement(Base): # A Measurement always comes from some DataSource. data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True) + data_source: Mapped[DataSource] = relationship( + "DataSource", back_populates="measurements" + ) # The quantity being measured. quantity_measured = Column(Unicode, index=True) diff --git a/core/model/patron.py b/core/model/patron.py index 1d4e27bcd0..f5e8ccf8dd 100644 --- a/core/model/patron.py +++ b/core/model/patron.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from core.model.library import Library - from core.model.licensing import LicensePool, LicensePoolDeliveryMechanism + from core.model.licensing import License, LicensePool, LicensePoolDeliveryMechanism from .devicetokens import DeviceToken @@ -544,6 +544,7 @@ class Loan(Base, LoanAndHoldMixin): # It may also be associated with an individual License if the source # provides information about individual licenses. license_id = Column(Integer, ForeignKey("licenses.id"), index=True, nullable=True) + license: Mapped[License] = relationship("License", back_populates="loans") fulfillment_id = Column(Integer, ForeignKey("licensepooldeliveries.id")) fulfillment: Mapped[Optional[LicensePoolDeliveryMechanism]] = relationship( diff --git a/core/model/resource.py b/core/model/resource.py index 8423e9fa52..475bb2eabd 100644 --- a/core/model/resource.py +++ b/core/model/resource.py @@ -10,7 +10,7 @@ import traceback from hashlib import md5 from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from urllib.parse import quote, urlparse, urlsplit import requests @@ -43,7 +43,7 @@ from .licensing import LicensePoolDeliveryMechanism if TYPE_CHECKING: - from core.model import CachedMARCFile, Work # noqa: autoflake + from core.model import CachedMARCFile class Resource(Base): @@ -1019,12 +1019,14 @@ def headers_to_string(cls, d): return json.dumps(dict(d)) @classmethod - def simple_http_get(cls, url, headers, **kwargs) -> Tuple[int, Any, Any]: + def simple_http_get( + cls, url, headers, **kwargs + ) -> Tuple[int, Dict[str, str], bytes]: """The most simple HTTP-based GET.""" if not "allow_redirects" in kwargs: kwargs["allow_redirects"] = True response = HTTP.get_with_timeout(url, headers=headers, **kwargs) - return response.status_code, response.headers, response.content + return response.status_code, response.headers, response.content # type: ignore[return-value] @classmethod def simple_http_post(cls, url, headers, **kwargs): diff --git a/core/opds2_import.py b/core/opds2_import.py index 39edf52071..f2128e9841 100644 --- a/core/opds2_import.py +++ b/core/opds2_import.py @@ -3,15 +3,16 @@ import logging from datetime import datetime from io import BytesIO, StringIO -from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Tuple, Type from urllib.parse import urljoin, urlparse import sqlalchemy import webpub_manifest_parser.opds2.ast as opds2_ast from flask_babel import lazy_gettext as _ +from sqlalchemy.orm import Session from webpub_manifest_parser.core import ManifestParserFactory, ManifestParserResult from webpub_manifest_parser.core.analyzer import NodeFinder -from webpub_manifest_parser.core.ast import Manifestlike +from webpub_manifest_parser.core.ast import Link, Manifestlike from webpub_manifest_parser.errors import BaseError from webpub_manifest_parser.opds2.registry import ( OPDS2LinkRelationsRegistry, @@ -79,7 +80,7 @@ def __init__(self, manifest_parser_factory: ManifestParserFactory): self._manifest_parser_factory = manifest_parser_factory def parse_manifest( - self, manifest: str | dict | Manifestlike + self, manifest: str | dict[str, Any] | Manifestlike ) -> ManifestParserResult: """Parse the feed into an RPWM-like AST object. @@ -145,25 +146,27 @@ class OPDS2Importer( NEXT_LINK_RELATION: str = "next" @classmethod - def settings_class(self): + def settings_class(cls) -> Type[OPDS2ImporterSettings]: return OPDS2ImporterSettings - def label(self): - return self.NAME + @classmethod + def label(cls) -> str: + return cls.NAME - def description(self): - return self.DESCRIPTION + @classmethod + def description(cls) -> str: + return cls.DESCRIPTION def __init__( self, - db: sqlalchemy.orm.session.Session, + db: Session, collection: Collection, parser: RWPMManifestParser, data_source_name: str | None = None, - identifier_mapping: dict | None = None, - http_get: Callable | None = None, - content_modifier: Callable | None = None, - map_from_collection: dict | None = None, + identifier_mapping: Dict[Identifier, Identifier] | None = None, + http_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + content_modifier: Optional[Callable[..., None]] = None, + map_from_collection: Optional[bool] = None, ): """Initialize a new instance of OPDS2Importer class. @@ -307,20 +310,16 @@ def _extract_contributors( return contributor_metadata_list - def _extract_link(self, link, feed_self_url, default_link_rel=None): + def _extract_link( + self, link: Link, feed_self_url: str, default_link_rel: Optional[str] = None + ) -> LinkData: """Extract a LinkData object from webpub-manifest-parser's link. :param link: webpub-manifest-parser's link - :type link: ast_core.Link - :param feed_self_url: Feed's self URL - :type feed_self_url: str - :param default_link_rel: Default link's relation - :type default_link_rel: Optional[str] :return: Link metadata - :rtype: LinkData """ self._logger.debug(f"Started extracting link metadata from {encode(link)}") @@ -599,7 +598,7 @@ def _extract_publication_metadata( self, feed: opds2_ast.OPDS2Feed, publication: opds2_ast.OPDS2Publication, - data_source_name: str, + data_source_name: Optional[str], ) -> Metadata: """Extract a Metadata object from webpub-manifest-parser's publication. @@ -783,6 +782,8 @@ def external_integration( :param db: Database session :return: External integration associated with this object """ + if self.collection is None: + raise ValueError("Collection is not set") return self.collection.external_integration def integration_configuration(self) -> IntegrationConfiguration: @@ -854,7 +855,7 @@ def _is_open_access_link_( def _record_coverage_failure( self, - failures: dict[str, list[CoverageFailure]], + failures: dict[str, list[CoverageFailure] | CoverageFailure], identifier: Identifier, error_message: str, transient: bool = True, @@ -880,7 +881,7 @@ def _record_coverage_failure( transient=transient, collection=self.collection, ) - failures[identifier.identifier].append(failure) + failures[identifier.identifier].append(failure) # type: ignore[union-attr] return failure @@ -917,11 +918,11 @@ def extract_next_links(self, feed: str | opds2_ast.OPDS2Feed) -> list[str]: next_links = parsed_feed.links.get_by_rel(self.NEXT_LINK_RELATION) next_links = [next_link.href for next_link in next_links] - return next_links + return next_links # type: ignore[no-any-return] def extract_last_update_dates( self, feed: str | opds2_ast.OPDS2Feed - ) -> list[tuple[str, datetime]]: + ) -> list[tuple[Optional[str], Optional[datetime]]]: """Extract last update date of the feed. :param feed: OPDS 2.0 feed @@ -947,13 +948,13 @@ def _parse_feed_links(self, links: list[core_ast.Link]) -> None: if first_or_default(link.rels) == Hyperlink.TOKEN_AUTH: # Save the collection-wide token authentication endpoint auth_setting = ConfigurationSetting.for_externalintegration( - ExternalIntegration.TOKEN_AUTH, self.collection.external_integration + ExternalIntegration.TOKEN_AUTH, self.external_integration(self._db) ) auth_setting.value = link.href def extract_feed_data( self, feed: str | opds2_ast.OPDS2Feed, feed_url: str | None = None - ) -> tuple[dict, dict]: + ) -> tuple[dict[str, Metadata], dict[str, list[CoverageFailure] | CoverageFailure]]: """Turn an OPDS 2.0 feed into lists of Metadata and CirculationData objects. :param feed: OPDS 2.0 feed :param feed_url: Feed URL used to resolve relative links @@ -961,7 +962,7 @@ def extract_feed_data( parser_result = self._parser.parse_manifest(feed) feed = parser_result.root publication_metadata_dictionary = {} - failures: dict[str, list[CoverageFailure]] = {} + failures: dict[str, list[CoverageFailure] | CoverageFailure] = {} if feed.links: self._parse_feed_links(feed.links) @@ -1011,7 +1012,9 @@ class OPDS2ImportMonitor(OPDSImportMonitor): PROTOCOL = ExternalIntegration.OPDS2_IMPORT MEDIA_TYPE = OPDS2MediaTypesRegistry.OPDS_FEED.key, "application/json" - def _verify_media_type(self, url, status_code, headers, feed): + def _verify_media_type( + self, url: str, status_code: int, headers: Dict[str, str], feed: bytes + ) -> None: # Make sure we got an OPDS feed, and not an error page that was # sent with a 200 status code. media_type = headers.get("content-type") @@ -1024,7 +1027,7 @@ def _verify_media_type(self, url, status_code, headers, feed): url, message=message, debug_message=feed, status_code=status_code ) - def _get_accept_header(self): + def _get_accept_header(self) -> str: return "{}, {};q=0.9, */*;q=0.1".format( OPDS2MediaTypesRegistry.OPDS_FEED.key, "application/json" ) diff --git a/core/opds_import.py b/core/opds_import.py index faca6de823..3eeebcfd4a 100644 --- a/core/opds_import.py +++ b/core/opds_import.py @@ -2,12 +2,29 @@ import logging import traceback +from datetime import datetime from io import BytesIO -from typing import TYPE_CHECKING, Optional -from urllib.parse import ParseResult, urljoin, urlparse +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + overload, +) +from urllib.parse import urljoin, urlparse +from xml.etree.ElementTree import Element import dateutil import feedparser +from feedparser import FeedParserDict from flask_babel import lazy_gettext as _ from lxml import etree from pydantic import HttpUrl @@ -56,6 +73,7 @@ ) from .model.configuration import HasExternalIntegration from .monitor import CollectionMonitor +from .selftest import SelfTestResult from .util.datetime_helpers import datetime_utc, to_utc, utc_now from .util.http import HTTP, BadResponseException from .util.opds_writer import OPDSFeed, OPDSMessage @@ -66,7 +84,17 @@ from .model import Work -def parse_identifier(db, identifier): +@overload +def parse_identifier(db: Session, identifier: str) -> Identifier: + ... + + +@overload +def parse_identifier(db: Session, identifier: Optional[str]) -> Optional[Identifier]: + ... + + +def parse_identifier(db: Session, identifier: Optional[str]) -> Optional[Identifier]: """Parse the identifier and return an Identifier object representing it. :param db: Database session @@ -199,7 +227,9 @@ class OPDSImporterLibrarySettings(BaseSettings): pass -class OPDSImporter(CirculationConfigurationMixin): +class OPDSImporter( + CirculationConfigurationMixin[OPDSImporterSettings, OPDSImporterLibrarySettings] +): """Imports editions and license pools from an OPDS feed. Creates Edition, LicensePool and Work rows in the database, if those don't already exist. @@ -229,28 +259,30 @@ class OPDSImporter(CirculationConfigurationMixin): SUCCESS_STATUS_CODES: list[int] | None = None @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[OPDSImporterSettings]: return OPDSImporterSettings @classmethod - def library_settings_class(cls): + def library_settings_class(cls) -> Type[OPDSImporterLibrarySettings]: return OPDSImporterLibrarySettings - def label(self): + @classmethod + def label(cls) -> str: return "OPDS Importer" - def description(self): - return self.DESCRIPTION + @classmethod + def description(cls) -> str: + return cls.DESCRIPTION # type: ignore[no-any-return] def __init__( self, - _db, - collection, - data_source_name=None, - identifier_mapping=None, - http_get=None, - content_modifier=None, - map_from_collection=None, + _db: Session, + collection: Optional[Collection], + data_source_name: Optional[str] = None, + identifier_mapping: Optional[Dict[Identifier, Identifier]] = None, + http_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + content_modifier: Optional[Callable[..., None]] = None, + map_from_collection: Optional[bool] = None, ): """:param collection: LicensePools created by this OPDS import will be associated with the given Collection. If this is None, @@ -308,7 +340,7 @@ def __init__( self.map_from_collection = map_from_collection @property - def collection(self): + def collection(self) -> Optional[Collection]: """Returns an associated Collection object :return: Associated Collection object @@ -320,19 +352,21 @@ def collection(self): return None @property - def data_source(self): + def data_source(self) -> DataSource: """Look up or create a DataSource object representing the source of this OPDS feed. """ offers_licenses = self.collection is not None - return DataSource.lookup( + return DataSource.lookup( # type: ignore[no-any-return] self._db, self.data_source_name, autocreate=True, offers_licenses=offers_licenses, ) - def assert_importable_content(self, feed, feed_url, max_get_attempts=5): + def assert_importable_content( + self, feed: str, feed_url: str, max_get_attempts: int = 5 + ) -> Literal[True]: """Raise an exception if the given feed contains nothing that can, even theoretically, be turned into a LicensePool. @@ -352,7 +386,7 @@ def assert_importable_content(self, feed, feed_url, max_get_attempts=5): url = link.href success = self._is_open_access_link(url, link.media_type) if success: - return success + return True get_attempts += 1 if get_attempts >= max_get_attempts: error = ( @@ -368,7 +402,9 @@ def assert_importable_content(self, feed, feed_url, max_get_attempts=5): ) @classmethod - def _open_access_links(cls, metadatas): + def _open_access_links( + cls, metadatas: List[Metadata] + ) -> Generator[LinkData, None, None]: """Find all open-access links in a list of Metadata objects. :param metadatas: A list of Metadata objects. @@ -381,7 +417,9 @@ def _open_access_links(cls, metadatas): if link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD: yield link - def _is_open_access_link(self, url, type): + def _is_open_access_link( + self, url: str, type: Optional[str] + ) -> str | Literal[False]: """Is `url` really an open-access link? That is, can we make a normal GET request and get something @@ -403,7 +441,7 @@ def _is_open_access_link(self, url, type): ) return False - def _parse_identifier(self, identifier): + def _parse_identifier(self, identifier: str) -> Identifier: """Parse the identifier and return an Identifier object representing it. :param identifier: String containing the identifier @@ -414,14 +452,19 @@ def _parse_identifier(self, identifier): """ return parse_identifier(self._db, identifier) - def import_from_feed(self, feed, feed_url=None): + def import_from_feed( + self, feed: str | bytes, feed_url: Optional[str] = None + ) -> Tuple[ + List[Edition], + List[LicensePool], + List[Work], + Dict[str, CoverageFailure | List[CoverageFailure]], + ]: # Keep track of editions that were imported. Pools and works # for those editions may be looked up or created. imported_editions = {} pools = {} works = {} - # CoverageFailures that note business logic errors and non-success download statuses - failures = {} # If parsing the overall feed throws an exception, we should address that before # moving on. Let the exception propagate. @@ -469,9 +512,10 @@ def import_from_feed(self, feed, feed_url=None): if work: works[key] = work except Exception as e: + collection_name = self.collection.name if self.collection else "None" logging.warning( f"Non-fatal exception: Failed to import item - import will continue: " - f"identifier={key}; collection={self.collection.name}; " + f"identifier={key}; collection={collection_name}/{self._collection_id}; " f"data_source={self.data_source}; exception={e}", stack_info=True, ) @@ -493,7 +537,7 @@ def import_from_feed(self, feed, feed_url=None): failures, ) - def import_edition_from_metadata(self, metadata): + def import_edition_from_metadata(self, metadata: Metadata) -> Edition: """For the passed-in Metadata object, see if can find or create an Edition in the database. Also create a LicensePool if the Metadata has CirculationData in it. @@ -517,12 +561,12 @@ def import_edition_from_metadata(self, metadata): replace=policy, ) - return edition + return edition # type: ignore[no-any-return] def update_work_for_edition( self, edition: Edition, - is_open_access=True, + is_open_access: bool = True, ) -> tuple[LicensePool | None, Work | None]: """If possible, ensure that there is a presentation-ready Work for the given edition's primary identifier. @@ -573,7 +617,7 @@ def update_work_for_edition( # background, and that's good enough. return pool, work - def extract_next_links(self, feed): + def extract_next_links(self, feed: str | bytes | FeedParserDict) -> List[str]: if isinstance(feed, (bytes, str)): parsed = feedparser.parse(feed) else: @@ -586,7 +630,9 @@ def extract_next_links(self, feed): ] return next_links - def extract_last_update_dates(self, feed): + def extract_last_update_dates( + self, feed: str | bytes | FeedParserDict + ) -> List[Tuple[Optional[str], Optional[datetime]]]: if isinstance(feed, (bytes, str)): parsed_feed = feedparser.parse(feed) else: @@ -597,7 +643,7 @@ def extract_last_update_dates(self, feed): ] return [x for x in dates if x and x[1]] - def build_identifier_mapping(self, external_urns): + def build_identifier_mapping(self, external_urns: List[str]) -> None: """Uses the given Collection and a list of URNs to reverse engineer an identifier mapping. @@ -632,7 +678,9 @@ def build_identifier_mapping(self, external_urns): self.identifier_mapping = mapping - def extract_feed_data(self, feed, feed_url=None): + def extract_feed_data( + self, feed: str | bytes, feed_url: Optional[str] = None + ) -> Tuple[Dict[str, Metadata], Dict[str, CoverageFailure | List[CoverageFailure]]]: """Turn an OPDS feed into lists of Metadata and CirculationData objects, with associated messages and next_links. """ @@ -652,16 +700,16 @@ def extract_feed_data(self, feed, feed_url=None): ) # translate the id in failures to identifier.urn - identified_failures = {} + identified_failures: Dict[str, CoverageFailure | List[CoverageFailure]] = {} for urn, failure in list(fp_failures.items()) + list(xml_failures.items()): identifier, failure = self.handle_failure(urn, failure) identified_failures[identifier.urn] = failure # Use one loop for both, since the id will be the same for both dictionaries. metadata = {} - circulationdata = {} - for id, m_data_dict in list(fp_metadata.items()): - xml_data_dict = xml_data_meta.get(id, {}) + _id: str + for _id, m_data_dict in list(fp_metadata.items()): + xml_data_dict = xml_data_meta.get(_id, {}) external_identifier = None if self.primary_identifier_source == ExternalIntegration.DCTERMS_IDENTIFIER: @@ -677,7 +725,7 @@ def extract_feed_data(self, feed, feed_url=None): # the external identifier will be add later, so it must be removed at this point new_identifiers = dcterms_ids[1:] # Id must be in the identifiers with lower weight. - id_type, id_identifier = Identifier.type_and_identifier_for_urn(id) + id_type, id_identifier = Identifier.type_and_identifier_for_urn(_id) id_weight = 1 new_identifiers.append( IdentifierData(id_type, id_identifier, id_weight) @@ -685,9 +733,10 @@ def extract_feed_data(self, feed, feed_url=None): xml_data_dict["identifiers"] = new_identifiers if external_identifier is None: - external_identifier, ignore = Identifier.parse_urn(self._db, id) + external_identifier, ignore = Identifier.parse_urn(self._db, _id) - if self.identifier_mapping: + internal_identifier: Optional[Identifier] + if self.identifier_mapping and external_identifier is not None: internal_identifier = self.identifier_mapping.get( external_identifier, external_identifier ) @@ -753,7 +802,21 @@ def extract_feed_data(self, feed, feed_url=None): pass return metadata, identified_failures - def handle_failure(self, urn, failure): + @overload + def handle_failure( + self, urn: str, failure: Identifier + ) -> Tuple[Identifier, Identifier]: + ... + + @overload + def handle_failure( + self, urn: str, failure: CoverageFailure + ) -> Tuple[Identifier, CoverageFailure]: + ... + + def handle_failure( + self, urn: str, failure: Identifier | CoverageFailure + ) -> Tuple[Identifier, CoverageFailure | Identifier]: """Convert a URN and a failure message that came in through an OPDS feed into an Identifier and a CoverageFailure object. @@ -785,7 +848,7 @@ def handle_failure(self, urn, failure): return internal_identifier, failure @classmethod - def _add_format_data(cls, circulation): + def _add_format_data(cls, circulation: CirculationData) -> None: """Subclasses that specialize OPDS Import can implement this method to add formats to a CirculationData object with information that allows a patron to actually get a book @@ -793,14 +856,16 @@ def _add_format_data(cls, circulation): """ @classmethod - def combine(self, d1, d2): + def combine( + self, d1: Optional[Dict[str, Any]], d2: Optional[Dict[str, Any]] + ) -> Dict[str, Any]: """Combine two dictionaries that can be used as keyword arguments to the Metadata constructor. """ if not d1 and not d2: return dict() if not d1: - return dict(d2) + return dict(d2) # type: ignore[arg-type] if not d2: return dict(d1) new_dict = dict(d1) @@ -828,7 +893,9 @@ def combine(self, d1, d2): pass return new_dict - def extract_data_from_feedparser(self, feed, data_source): + def extract_data_from_feedparser( + self, feed: str | bytes, data_source: DataSource + ) -> Tuple[Dict[str, Any], Dict[str, CoverageFailure]]: feedparser_parsed = feedparser.parse(feed) values = {} failures = {} @@ -849,15 +916,18 @@ def extract_data_from_feedparser(self, feed, data_source): # That's bad. Can't make an item-specific error message, but write to # log that something very wrong happened. logging.error( - "Tried to parse an element without a valid identifier. feed=%s" - % feed + f"Tried to parse an element without a valid identifier. feed={feed!r}" ) return values, failures @classmethod def extract_metadata_from_elementtree( - cls, feed, data_source, feed_url=None, do_get=None - ): + cls, + feed: bytes | str, + data_source: DataSource, + feed_url: Optional[str] = None, + do_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + ) -> Tuple[Dict[str, Any], Dict[str, CoverageFailure]]: """Parse the OPDS as XML and extract all author and subject information, as well as ratings and medium. @@ -903,30 +973,34 @@ def extract_metadata_from_elementtree( # Then turn Atom tags into Metadata objects. for entry in parser._xpath(root, "/atom:feed/atom:entry"): - identifier, detail, failure = cls.detail_for_elementtree_entry( + identifier, detail, failure_entry = cls.detail_for_elementtree_entry( parser, entry, data_source, feed_url, do_get=do_get ) if identifier: - if failure: - failures[identifier] = failure + if failure_entry: + failures[identifier] = failure_entry if detail: values[identifier] = detail return values, failures @classmethod - def _datetime(cls, entry, key): + def _datetime(cls, entry: Dict[str, str], key: str) -> Optional[datetime]: value = entry.get(key, None) if not value: - return value + return None return datetime_utc(*value[:6]) - def last_update_date_for_feedparser_entry(self, entry): + def last_update_date_for_feedparser_entry( + self, entry: Dict[str, Any] + ) -> Tuple[Optional[str], Optional[datetime]]: identifier = entry.get("id") updated = self._datetime(entry, "updated_parsed") - return (identifier, updated) + return identifier, updated @classmethod - def data_detail_for_feedparser_entry(cls, entry, data_source): + def data_detail_for_feedparser_entry( + cls, entry: Dict[str, str], data_source: DataSource + ) -> Tuple[Optional[str], Optional[Dict[str, Any]], Optional[CoverageFailure]]: """Turn an entry dictionary created by feedparser into dictionaries of data that can be used as keyword arguments to the Metadata and CirculationData constructors. @@ -950,7 +1024,9 @@ def data_detail_for_feedparser_entry(cls, entry, data_source): return identifier, None, failure @classmethod - def _data_detail_for_feedparser_entry(cls, entry, metadata_data_source): + def _data_detail_for_feedparser_entry( + cls, entry: Dict[str, Any], metadata_data_source: DataSource + ) -> Dict[str, Any]: """Helper method that extracts metadata and circulation data from a feedparser entry. This method can be overridden in tests to check that callers handle things properly when it throws an exception. @@ -1010,7 +1086,7 @@ def _data_detail_for_feedparser_entry(cls, entry, metadata_data_source): links = [] - def summary_to_linkdata(detail): + def summary_to_linkdata(detail: Optional[Dict[str, str]]) -> Optional[LinkData]: if not detail: return None if not "value" in detail or not detail["value"]: @@ -1056,14 +1132,14 @@ def summary_to_linkdata(detail): return kwargs_meta @classmethod - def rights_uri(cls, rights_string): + def rights_uri(cls, rights_string: str) -> str: """Determine the URI that best encapsulates the rights status of the downloads associated with this book. """ return RightsStatus.rights_uri_from_string(rights_string) @classmethod - def rights_uri_from_feedparser_entry(cls, entry): + def rights_uri_from_feedparser_entry(cls, entry: Dict[str, str]) -> str: """Extract a rights URI from a parsed feedparser entry. :return: A rights URI. @@ -1072,17 +1148,20 @@ def rights_uri_from_feedparser_entry(cls, entry): return cls.rights_uri(rights) @classmethod - def rights_uri_from_entry_tag(cls, entry): + def rights_uri_from_entry_tag(cls, entry: Element) -> Optional[str]: """Extract a rights string from an lxml tag. :return: A rights URI. """ rights = cls.PARSER_CLASS._xpath1(entry, "rights") - if rights: - return cls.rights_uri(rights) + if rights is None: + return None + return cls.rights_uri(rights) @classmethod - def extract_messages(cls, parser, feed_tag): + def extract_messages( + cls, parser: OPDSXMLParser, feed_tag: str + ) -> Generator[OPDSMessage, None, None]: """Extract tags from an OPDS feed and convert them into OPDSMessage objects. """ @@ -1116,7 +1195,9 @@ def extract_messages(cls, parser, feed_tag): yield OPDSMessage(urn, status_code, description) @classmethod - def coveragefailures_from_messages(cls, data_source, parser, feed_tag): + def coveragefailures_from_messages( + cls, data_source: DataSource, parser: OPDSXMLParser, feed_tag: str + ) -> Generator[CoverageFailure, None, None]: """Extract CoverageFailure objects from a parsed OPDS document. This allows us to determine the fate of books which could not become tags. @@ -1127,7 +1208,9 @@ def coveragefailures_from_messages(cls, data_source, parser, feed_tag): yield failure @classmethod - def coveragefailure_from_message(cls, data_source, message): + def coveragefailure_from_message( + cls, data_source: DataSource, message: OPDSMessage + ) -> Optional[CoverageFailure]: """Turn a tag into a CoverageFailure.""" _db = Session.object_session(data_source) @@ -1149,7 +1232,7 @@ def coveragefailure_from_message(cls, data_source, message): if cls.SUCCESS_STATUS_CODES and message.status_code in cls.SUCCESS_STATUS_CODES: # This message is telling us that nothing went wrong. It # should be treated as a success. - return identifier + return identifier # type: ignore[no-any-return] if message.status_code == 200: # By default, we treat a message with a 200 status code @@ -1173,8 +1256,13 @@ def coveragefailure_from_message(cls, data_source, message): @classmethod def detail_for_elementtree_entry( - cls, parser, entry_tag, data_source, feed_url=None, do_get=None - ): + cls, + parser: OPDSXMLParser, + entry_tag: Element, + data_source: DataSource, + feed_url: Optional[str] = None, + do_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + ) -> Tuple[Optional[str], Optional[Dict[str, Any]], Optional[CoverageFailure]]: """Turn an tag into a dictionary of metadata that can be used as keyword arguments to the Metadata contructor. @@ -1203,15 +1291,19 @@ def detail_for_elementtree_entry( @classmethod def _detail_for_elementtree_entry( - cls, parser, entry_tag, feed_url=None, do_get=None - ): + cls, + parser: OPDSXMLParser, + entry_tag: Element, + feed_url: Optional[str] = None, + do_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + ) -> Dict[str, Any]: """Helper method that extracts metadata and circulation data from an elementtree entry. This method can be overridden in tests to check that callers handle things properly when it throws an exception. """ # We will fill this dictionary with all the information # we can find. - data = dict() + data: Dict[str, Any] = dict() alternate_identifiers = [] for id_tag in parser._xpath(entry_tag, "dcterms:identifier"): @@ -1236,9 +1328,9 @@ def _detail_for_elementtree_entry( ratings = [] for rating_tag in parser._xpath(entry_tag, "schema:Rating"): - v = cls.extract_measurement(rating_tag) - if v: - ratings.append(v) + measurement = cls.extract_measurement(rating_tag) + if measurement: + ratings.append(measurement) data["measurements"] = ratings rights_uri = cls.rights_uri_from_entry_tag(entry_tag) @@ -1271,7 +1363,7 @@ def _detail_for_elementtree_entry( return data @classmethod - def get_medium_from_links(cls, links): + def get_medium_from_links(cls, links: List[LinkData]) -> Optional[str]: """Get medium if derivable from information in an acquisition link.""" derived = None for link in links: @@ -1287,9 +1379,11 @@ def get_medium_from_links(cls, links): return derived @classmethod - def extract_identifier(cls, identifier_tag): + def extract_identifier(cls, identifier_tag: Element) -> Optional[IdentifierData]: """Turn a tag into an IdentifierData object.""" try: + if identifier_tag.text is None: + return None type, identifier = Identifier.type_and_identifier_for_urn( identifier_tag.text.lower() ) @@ -1298,7 +1392,9 @@ def extract_identifier(cls, identifier_tag): return None @classmethod - def extract_medium(cls, entry_tag, default=Edition.BOOK_MEDIUM): + def extract_medium( + cls, entry_tag: Optional[Element], default: Optional[str] = Edition.BOOK_MEDIUM + ) -> Optional[str]: """Derive a value for Edition.medium from schema:additionalType or from a subtag. @@ -1320,7 +1416,9 @@ def extract_medium(cls, entry_tag, default=Edition.BOOK_MEDIUM): return medium or default @classmethod - def extract_contributor(cls, parser, author_tag): + def extract_contributor( + cls, parser: OPDSXMLParser, author_tag: Element + ) -> Optional[ContributorData]: """Turn an tag into a ContributorData object.""" subtag = parser.text_of_optional_subtag sort_name = subtag(author_tag, "simplified:sort_name") @@ -1350,14 +1448,16 @@ def extract_contributor(cls, parser, author_tag): return None @classmethod - def extract_subject(cls, parser, category_tag): + def extract_subject( + cls, parser: OPDSXMLParser, category_tag: Element + ) -> SubjectData: """Turn an tag into a SubjectData object.""" attr = category_tag.attrib # Retrieve the type of this subject - FAST, Dewey Decimal, # etc. scheme = attr.get("scheme") - subject_type = Subject.by_uri.get(scheme) + subject_type = Subject.by_uri.get(scheme) # type: ignore[arg-type] if not subject_type: # We can't represent this subject because we don't # know its scheme. Just treat it as a tag. @@ -1378,7 +1478,12 @@ def extract_subject(cls, parser, category_tag): return SubjectData(type=subject_type, identifier=term, name=name, weight=weight) @classmethod - def extract_link(cls, link_tag, feed_url=None, entry_rights_uri=None): + def extract_link( + cls, + link_tag: Element, + feed_url: Optional[str] = None, + entry_rights_uri: Optional[str] = None, + ) -> Optional[LinkData]: """Convert a tag into a LinkData object. :param feed_url: The URL to the enclosing feed, for use in resolving @@ -1398,12 +1503,12 @@ def extract_link(cls, link_tag, feed_url=None, entry_rights_uri=None): # relationship to the entry. return None rights = attr.get("{%s}rights" % OPDSXMLParser.NAMESPACES["dcterms"]) + rights_uri = entry_rights_uri if rights: # Rights associated with the link override rights # associated with the entry. rights_uri = cls.rights_uri(rights) - else: - rights_uri = entry_rights_uri + if feed_url and not urlparse(href).netloc: # This link is relative, so we need to get the absolute url href = urljoin(feed_url, href) @@ -1411,8 +1516,13 @@ def extract_link(cls, link_tag, feed_url=None, entry_rights_uri=None): @classmethod def make_link_data( - cls, rel, href=None, media_type=None, rights_uri=None, content=None - ): + cls, + rel: str, + href: Optional[str] = None, + media_type: Optional[str] = None, + rights_uri: Optional[str] = None, + content: Optional[str] = None, + ) -> LinkData: """Hook method for creating a LinkData object. Intended to be overridden in subclasses. @@ -1426,13 +1536,13 @@ def make_link_data( ) @classmethod - def consolidate_links(cls, links): + def consolidate_links(cls, links: Sequence[LinkData | None]) -> List[LinkData]: """Try to match up links with their thumbnails. If link n is an image and link n+1 is a thumbnail, then the thumbnail is assumed to be the thumbnail of the image. - Similarly if link n is a thumbnail and link n+1 is an image. + Similarly, if link n is a thumbnail and link n+1 is an image. """ # Strip out any links that didn't get turned into LinkData objects # due to missing `href` or whatever. @@ -1441,10 +1551,10 @@ def consolidate_links(cls, links): # Make a new list of links from that list, to iterate over -- # we'll be modifying new_links in place so we can't iterate # over it. - links = list(new_links) + _links = list(new_links) next_link_already_handled = False - for i, link in enumerate(links): + for i, link in enumerate(_links): if link.rel not in (Hyperlink.THUMBNAIL_IMAGE, Hyperlink.IMAGE): # This is not any kind of image. Ignore it. continue @@ -1455,13 +1565,13 @@ def consolidate_links(cls, links): next_link_already_handled = False continue - if i == len(links) - 1: + if i == len(_links) - 1: # This is the last link. Since there is no next link # there's nothing to do here. continue # Peek at the next link. - next_link = links[i + 1] + next_link = _links[i + 1] if ( link.rel == Hyperlink.THUMBNAIL_IMAGE @@ -1489,24 +1599,28 @@ def consolidate_links(cls, links): return new_links @classmethod - def extract_measurement(cls, rating_tag): + def extract_measurement(cls, rating_tag: Element) -> Optional[MeasurementData]: type = rating_tag.get("{http://schema.org/}additionalType") value = rating_tag.get("{http://schema.org/}ratingValue") if not value: value = rating_tag.attrib.get("{http://schema.org}ratingValue") if not type: type = Measurement.RATING + + if value is None: + return None + try: - value = float(value) + float_value = float(value) return MeasurementData( quantity_measured=type, - value=value, + value=float_value, ) except ValueError: return None @classmethod - def extract_series(cls, series_tag): + def extract_series(cls, series_tag: Element) -> Tuple[Optional[str], Optional[str]]: attr = series_tag.attrib series_name = attr.get("{http://schema.org/}name", None) series_position = attr.get("{http://schema.org/}position", None) @@ -1532,12 +1646,12 @@ class OPDSImportMonitor( def __init__( self, - _db, + _db: Session, collection: Collection, - import_class, - force_reimport=False, - **import_class_kwargs, - ): + import_class: Type[OPDSImporter], + force_reimport: bool = False, + **import_class_kwargs: Any, + ) -> None: if not collection: raise ValueError( "OPDSImportMonitor can only be run in the context of a Collection." @@ -1556,7 +1670,9 @@ def __init__( ) self.external_integration_id = collection.external_integration.id - self.feed_url = self.opds_url(collection) + feed_url = self.opds_url(collection) + self.feed_url = "" if feed_url is None else feed_url + self.force_reimport = force_reimport self.importer = import_class(_db, collection=collection, **import_class_kwargs) @@ -1576,14 +1692,14 @@ def __init__( except AttributeError: self._max_retry_count = 0 - parsed_url: ParseResult = urlparse(self.feed_url) + parsed_url = urlparse(self.feed_url) self._feed_base_url = f"{parsed_url.scheme}://{parsed_url.hostname}{(':' + str(parsed_url.port)) if parsed_url.port else ''}/" super().__init__(_db, collection) - def external_integration(self, _db): + def external_integration(self, _db: Session) -> Optional[ExternalIntegration]: return get_one(_db, ExternalIntegration, id=self.external_integration_id) - def _run_self_tests(self, _db): + def _run_self_tests(self, _db: Session) -> Generator[SelfTestResult, None, None]: """Retrieve the first page of the OPDS feed""" first_page = self.run_test( "Retrieve the first page of the OPDS feed (%s)" % self.feed_url, @@ -1606,7 +1722,9 @@ def _run_self_tests(self, _db): self.feed_url, ) - def _get(self, url, headers): + def _get( + self, url: str, headers: Dict[str, str] + ) -> Tuple[int, Dict[str, str], bytes]: """Make the sort of HTTP request that's normal for an OPDS feed. Long timeout, raise error on anything but 2xx or 3xx. @@ -1621,9 +1739,9 @@ def _get(self, url, headers): if not url.startswith("http"): url = urljoin(self._feed_base_url, url) response = HTTP.get_with_timeout(url, headers=headers, **kwargs) - return response.status_code, response.headers, response.content + return response.status_code, response.headers, response.content # type: ignore[return-value] - def _get_accept_header(self): + def _get_accept_header(self) -> str: return ",".join( [ OPDSFeed.ACQUISITION_FEED_TYPE, @@ -1633,7 +1751,7 @@ def _get_accept_header(self): ] ) - def _update_headers(self, headers): + def _update_headers(self, headers: Optional[Dict[str, str]]) -> Dict[str, str]: headers = dict(headers) if headers else {} if self.username and self.password and not "Authorization" in headers: headers["Authorization"] = "Basic %s" % base64.b64encode( @@ -1647,7 +1765,7 @@ def _update_headers(self, headers): return headers - def _parse_identifier(self, identifier): + def _parse_identifier(self, identifier: Optional[str]) -> Optional[Identifier]: """Extract the publication's identifier from its metadata. :param identifier: String containing the identifier @@ -1658,7 +1776,7 @@ def _parse_identifier(self, identifier): """ return parse_identifier(self._db, identifier) - def opds_url(self, collection): + def opds_url(self, collection: Collection) -> Optional[str]: """Returns the OPDS import URL for the given collection. By default, this URL is stored as the external account ID, but @@ -1666,15 +1784,15 @@ def opds_url(self, collection): """ return collection.external_account_id - def data_source(self, collection): + def data_source(self, collection: Collection) -> Optional[DataSource]: """Returns the data source name for the given collection. By default, this URL is stored as a setting on the collection, but subclasses may hard-code it. """ - return collection.data_source + return collection.data_source # type: ignore[no-any-return] - def feed_contains_new_data(self, feed): + def feed_contains_new_data(self, feed: bytes | str) -> bool: """Does the given feed contain any entries that haven't been imported yet? """ @@ -1704,7 +1822,9 @@ def feed_contains_new_data(self, feed): break return new_data - def identifier_needs_import(self, identifier, last_updated_remote): + def identifier_needs_import( + self, identifier: Optional[Identifier], last_updated_remote: Optional[datetime] + ) -> bool: """Does the remote side have new information about this Identifier? :param identifier: An Identifier. @@ -1766,8 +1886,11 @@ def identifier_needs_import(self, identifier, last_updated_remote): last_updated_remote, ) return True + return False - def _verify_media_type(self, url, status_code, headers, feed): + def _verify_media_type( + self, url: str, status_code: int, headers: Dict[str, str], feed: bytes + ) -> None: # Make sure we got an OPDS feed, and not an error page that was # sent with a 200 status code. media_type = headers.get("content-type") @@ -1779,7 +1902,9 @@ def _verify_media_type(self, url, status_code, headers, feed): url, message=message, debug_message=feed, status_code=status_code ) - def follow_one_link(self, url, do_get=None): + def follow_one_link( + self, url: str, do_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None + ) -> Tuple[List[str], Optional[bytes]]: """Download a representation of a URL and extract the useful information. @@ -1806,7 +1931,9 @@ def follow_one_link(self, url, do_get=None): self.log.info("No new data.") return [], None - def import_one_feed(self, feed): + def import_one_feed( + self, feed: bytes | str + ) -> Tuple[List[Edition], Dict[str, CoverageFailure | List[CoverageFailure]]]: """Import every book mentioned in an OPDS feed.""" # Because we are importing into a Collection, we will immediately @@ -1827,6 +1954,7 @@ def import_one_feed(self, feed): # Create CoverageRecords for the failures. for urn, failure in list(failures.items()): + failure_items: List[CoverageFailure] if isinstance(failure, list): failure_items = failure else: @@ -1839,7 +1967,7 @@ def import_one_feed(self, feed): return imported_editions, failures - def _get_feeds(self): + def _get_feeds(self) -> Iterable[Tuple[str, bytes]]: feeds = [] queue = [self.feed_url] seen_links = set() @@ -1863,11 +1991,9 @@ def _get_feeds(self): # Start importing at the end. If something fails, it will be easier to # pick up where we left off. - feeds = reversed(feeds) - - return feeds + return reversed(feeds) - def run_once(self, progress_ignore): + def run_once(self, progress: TimestampData) -> TimestampData: feeds = self._get_feeds() total_imported = 0 total_failures = 0 diff --git a/core/util/datetime_helpers.py b/core/util/datetime_helpers.py index 4d236984cc..6acfdc0a51 100644 --- a/core/util/datetime_helpers.py +++ b/core/util/datetime_helpers.py @@ -1,5 +1,5 @@ import datetime -from typing import Optional, Tuple +from typing import Optional, Tuple, overload import pytz from dateutil.relativedelta import relativedelta @@ -35,6 +35,16 @@ def utc_now() -> datetime.datetime: return datetime.datetime.now(tz=pytz.UTC) +@overload +def to_utc(dt: datetime.datetime) -> datetime.datetime: + ... + + +@overload +def to_utc(dt: Optional[datetime.datetime]) -> Optional[datetime.datetime]: + ... + + def to_utc(dt: Optional[datetime.datetime]) -> Optional[datetime.datetime]: """This converts a naive datetime object that represents UTC into an aware datetime object. diff --git a/core/util/xmlparser.py b/core/util/xmlparser.py index 1c3e11262f..2f3f998649 100644 --- a/core/util/xmlparser.py +++ b/core/util/xmlparser.py @@ -1,8 +1,16 @@ +from __future__ import annotations + from io import BytesIO -from typing import Dict +from typing import TYPE_CHECKING, Dict, List, Optional, TypeVar from lxml import etree +if TYPE_CHECKING: + from lxml.etree import Element + + +T = TypeVar("T") + class XMLParser: @@ -11,44 +19,56 @@ class XMLParser: NAMESPACES: Dict[str, str] = {} @classmethod - def _xpath(cls, tag, expression, namespaces=None): + def _xpath( + cls, tag: Element, expression: str, namespaces: Optional[Dict[str, str]] = None + ) -> List[Element]: if not namespaces: namespaces = cls.NAMESPACES """Wrapper to do a namespaced XPath expression.""" return tag.xpath(expression, namespaces=namespaces) @classmethod - def _xpath1(cls, tag, expression, namespaces=None): + def _xpath1( + cls, tag: Element, expression: str, namespaces: Optional[Dict[str, str]] = None + ) -> Optional[Element]: """Wrapper to do a namespaced XPath expression.""" values = cls._xpath(tag, expression, namespaces=namespaces) if not values: return None return values[0] - def _cls(self, tag_name, class_name): + def _cls(self, tag_name: str, class_name: str) -> str: """Return an XPath expression that will find a tag with the given CSS class.""" return ( 'descendant-or-self::node()/%s[contains(concat(" ", normalize-space(@class), " "), " %s ")]' % (tag_name, class_name) ) - def text_of_optional_subtag(self, tag, name, namespaces=None): + def text_of_optional_subtag( + self, tag: Element, name: str, namespaces: Optional[Dict[str, str]] = None + ) -> Optional[str]: tag = self._xpath1(tag, name, namespaces=namespaces) if tag is None or tag.text is None: return None else: return str(tag.text) - def text_of_subtag(self, tag, name, namespaces=None): + def text_of_subtag( + self, tag: Element, name: str, namespaces: Optional[Dict[str, str]] = None + ) -> str: return str(tag.xpath(name, namespaces=namespaces)[0].text) - def int_of_subtag(self, tag, name, namespaces=None): + def int_of_subtag( + self, tag: Element, name: str, namespaces: Optional[Dict[str, str]] = None + ) -> int: return int(self.text_of_subtag(tag, name, namespaces=namespaces)) - def int_of_optional_subtag(self, tag, name, namespaces=None): + def int_of_optional_subtag( + self, tag: Element, name: str, namespaces: Optional[Dict[str, str]] = None + ) -> Optional[int]: v = self.text_of_optional_subtag(tag, name, namespaces=namespaces) if not v: - return v + return None return int(v) def process_all(self, xml, xpath, namespaces=None, handler=None, parser=None): diff --git a/pyproject.toml b/pyproject.toml index 9fc86b06f8..7afd0b0d7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,12 +78,18 @@ module = [ "api.circulation", "api.discovery.*", "api.integration.*", + "api.lcp.hash", + "api.odl", + "api.odl2", + "api.opds_for_distributors", "core.feed.*", "core.integration.*", "core.model.announcements", "core.model.hassessioncache", "core.model.integration", "core.model.library", + "core.opds2_import", + "core.opds_import", "core.selftest", "core.service.*", "core.settings.*", diff --git a/tests/api/feed/test_library_annotator.py b/tests/api/feed/test_library_annotator.py index 2670225ac2..694fabd407 100644 --- a/tests/api/feed/test_library_annotator.py +++ b/tests/api/feed/test_library_annotator.py @@ -946,7 +946,7 @@ def test_active_loan_feed( tree = etree.fromstring(response.get_data(as_text=True)) parser = OPDSXMLParser() licensor = parser._xpath1(tree, "//atom:feed/drm:licensor") - + assert licensor is not None adobe_patron_identifier = AuthdataUtility._adobe_patron_identifier(patron) # The DRM licensing information includes the Adobe vendor ID @@ -1021,7 +1021,11 @@ def test_active_loan_feed( ) assert 2 == len(acquisitions) - availabilities = [parser._xpath1(x, "opds:availability") for x in acquisitions] + availabilities = [] + for x in acquisitions: + availability = parser._xpath1(x, "opds:availability") + assert availability is not None + availabilities.append(availability) # One of these availability tags has 'since' but not 'until'. # The other one has both. diff --git a/tests/api/test_odl.py b/tests/api/test_odl.py index b483f71901..518e6cf4de 100644 --- a/tests/api/test_odl.py +++ b/tests/api/test_odl.py @@ -1242,7 +1242,7 @@ def test_release_hold_success( odl_api_test_fixture.checkout(patron=loan_patron) odl_api_test_fixture.pool.on_hold_to(odl_api_test_fixture.patron, position=1) - assert True == odl_api_test_fixture.api.release_hold( + odl_api_test_fixture.api.release_hold( odl_api_test_fixture.patron, "pin", odl_api_test_fixture.pool ) assert 0 == odl_api_test_fixture.pool.licenses_available @@ -1253,7 +1253,7 @@ def test_release_hold_success( odl_api_test_fixture.pool.on_hold_to(odl_api_test_fixture.patron, position=0) odl_api_test_fixture.checkin(patron=loan_patron) - assert True == odl_api_test_fixture.api.release_hold( + odl_api_test_fixture.api.release_hold( odl_api_test_fixture.patron, "pin", odl_api_test_fixture.pool ) assert 1 == odl_api_test_fixture.pool.licenses_available @@ -1266,7 +1266,7 @@ def test_release_hold_success( db.patron(), position=2 ) - assert True == odl_api_test_fixture.api.release_hold( + odl_api_test_fixture.api.release_hold( odl_api_test_fixture.patron, "pin", odl_api_test_fixture.pool ) assert 0 == odl_api_test_fixture.pool.licenses_available diff --git a/tests/api/test_opds.py b/tests/api/test_opds.py index ba51138402..3decbf4040 100644 --- a/tests/api/test_opds.py +++ b/tests/api/test_opds.py @@ -1226,6 +1226,7 @@ def test_active_loan_feed( tree = etree.fromstring(response.get_data(as_text=True)) parser = OPDSXMLParser() licensor = parser._xpath1(tree, "//atom:feed/drm:licensor") + assert licensor is not None adobe_patron_identifier = AuthdataUtility._adobe_patron_identifier(patron) @@ -1294,7 +1295,11 @@ def test_active_loan_feed( ) assert 2 == len(acquisitions) - availabilities = [parser._xpath1(x, "opds:availability") for x in acquisitions] + availabilities = [] + for acquisition in acquisitions: + availability = parser._xpath1(acquisition, "opds:availability") + assert availability is not None + availabilities.append(availability) # One of these availability tags has 'since' but not 'until'. # The other one has both. @@ -1862,6 +1867,7 @@ def test_acquisition_links( opds_parser = OPDSXMLParser() availability = opds_parser._xpath1(fulfill, "opds:availability") + assert availability is not None assert _strftime(loan1.start) == availability.attrib.get("since") assert loan1.end == availability.attrib.get("until") assert None == loan1.end diff --git a/tests/api/test_opds_for_distributors.py b/tests/api/test_opds_for_distributors.py index a7ae2a868a..513fea0795 100644 --- a/tests/api/test_opds_for_distributors.py +++ b/tests/api/test_opds_for_distributors.py @@ -1,11 +1,10 @@ import datetime import json from typing import Callable, Union -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -import core.opds_import from api.circulation_exceptions import * from api.opds_for_distributors import ( OPDSForDistributorsAPI, @@ -29,6 +28,7 @@ RightsStatus, Timestamp, create, + get_one, ) from core.util.datetime_helpers import utc_now from core.util.opds_writer import OPDSFeed @@ -147,7 +147,7 @@ def test_can_fulfill_without_loan( fulfilled with no underlying loan, if its delivery mechanism uses bearer token fulfillment. """ - patron = object() + patron = MagicMock() pool = opds_dist_api_fixture.db.licensepool( edition=None, collection=opds_dist_api_fixture.collection ) @@ -156,11 +156,11 @@ def test_can_fulfill_without_loan( m = opds_dist_api_fixture.api.can_fulfill_without_loan # No LicensePoolDeliveryMechanism -> False - assert False == m(patron, pool, None) + assert False == m(patron, pool, MagicMock()) # No LicensePool -> False (there can be multiple LicensePools for # a single LicensePoolDeliveryMechanism). - assert False == m(patron, None, lpdm) + assert False == m(patron, MagicMock(), lpdm) # No DeliveryMechanism -> False old_dm = lpdm.delivery_mechanism @@ -410,6 +410,7 @@ def test_checkout(self, opds_dist_api_fixture: OPDSForDistributorsAPIFixture): # The loan's start date has been set to the current time. now = utc_now() + assert loan_info.start_date is not None assert (now - loan_info.start_date).seconds < 2 # The loan is of indefinite duration. @@ -471,6 +472,7 @@ def test_fulfill(self, opds_dist_api_fixture: OPDSForDistributorsAPIFixture): assert None == fulfillment_info.content_link assert DeliveryMechanism.BEARER_TOKEN == fulfillment_info.content_type + assert fulfillment_info.content is not None bearer_token_document = json.loads(fulfillment_info.content) expires_in = bearer_token_document["expires_in"] assert expires_in < 60 @@ -483,6 +485,7 @@ def test_fulfill(self, opds_dist_api_fixture: OPDSForDistributorsAPIFixture): # bearer token expires to the time at which the title was # originally fulfilled. expect_expiration = fulfillment_time + datetime.timedelta(seconds=expires_in) + assert fulfillment_info.content_expires is not None assert ( abs((fulfillment_info.content_expires - expect_expiration).total_seconds()) < 5 @@ -708,9 +711,7 @@ def setup_collection(*, name: str, datasource: DataSource) -> Collection: collection=collection2, ) - with patch( - "core.opds_import.get_one", wraps=core.opds_import.get_one - ) as get_one_mock: + with patch("core.opds_import.get_one", wraps=get_one) as get_one_mock: importer1_lp, _ = importer1.update_work_for_edition(edition) importer2_lp, _ = importer2.update_work_for_edition(edition) diff --git a/tests/api/test_selftest.py b/tests/api/test_selftest.py index 1865478fff..b7eea047ee 100644 --- a/tests/api/test_selftest.py +++ b/tests/api/test_selftest.py @@ -5,6 +5,7 @@ from io import StringIO from typing import TYPE_CHECKING from unittest import mock +from unittest.mock import MagicMock import pytest @@ -302,7 +303,7 @@ def _no_delivery_mechanisms_test(self): return "1" mock = Mock() - results = [x for x in mock._run_self_tests()] + results = [x for x in mock._run_self_tests(MagicMock())] assert ["1"] == [x.result for x in results] assert True == mock._no_delivery_mechanisms_called diff --git a/tests/core/test_opds.py b/tests/core/test_opds.py index 21c3a96faf..ce3d463672 100644 --- a/tests/core/test_opds.py +++ b/tests/core/test_opds.py @@ -855,12 +855,14 @@ def test_acquisition_feed_includes_available_and_issued_tag( entries = OPDSXMLParser._xpath(with_times, "/atom:feed/atom:entry") parsed = [] for entry in entries: - title = OPDSXMLParser._xpath1(entry, "atom:title").text + title_element = OPDSXMLParser._xpath1(entry, "atom:title") + assert title_element is not None + title = title_element.text issued = OPDSXMLParser._xpath1(entry, "dcterms:issued") - if issued != None: + if issued is not None: issued = issued.text published = OPDSXMLParser._xpath1(entry, "atom:published") - if published != None: + if published is not None: published = published.text parsed.append( dict( diff --git a/tests/core/test_opds_import.py b/tests/core/test_opds_import.py index 4a4a8369f1..2a4088b890 100644 --- a/tests/core/test_opds_import.py +++ b/tests/core/test_opds_import.py @@ -1,7 +1,7 @@ import random from io import StringIO from typing import Optional -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest import requests_mock @@ -118,7 +118,7 @@ def test_constructor(self, opds_importer_fixture: OPDSImporterFixture): assert Representation.cautious_http_get == importer.http_get # But you can pass in anything you want. - do_get = object() + do_get = MagicMock() importer = OPDSImporter(session, collection=None, http_get=do_get) assert do_get == importer.http_get @@ -229,6 +229,7 @@ def test_extract_metadata(self, opds_importer_fixture: OPDSImporterFixture): assert data_source_name == c2._data_source [failure] = list(failures.values()) + assert isinstance(failure, CoverageFailure) assert ( "202: I'm working to locate a source for this identifier." == failure.exception @@ -260,10 +261,10 @@ def test_use_dcterm_identifier_as_id_with_id_and_dcterms_identifier( # First book doesn't have , so must be used as identifier book_1 = metadata.get("https://root.uri/1") - assert book_1 != None + assert book_1 is not None # Second book have and , so must be used as id book_2 = metadata.get("urn:isbn:9781468316438") - assert book_2 != None + assert book_2 is not None # Verify if id was add in the end of identifier book_2_identifiers = book_2.identifiers found = False @@ -271,10 +272,10 @@ def test_use_dcterm_identifier_as_id_with_id_and_dcterms_identifier( if entry.identifier == "https://root.uri/2": found = True break - assert found == True + assert found is True # Third book has more than one dcterms:identifers, all of then must be present as metadata identifier book_3 = metadata.get("urn:isbn:9781683351993") - assert book_2 != None + assert book_3 is not None # Verify if id was add in the end of identifier book_3_identifiers = book_3.identifiers expected_identifier = [ @@ -857,7 +858,7 @@ def test_import(self, opds_importer_fixture: OPDSImporterFixture): session, collection=None ).import_from_feed(feed) - [crow, mouse] = sorted(imported_editions, key=lambda x: x.title) + [crow, mouse] = sorted(imported_editions, key=lambda x: str(x.title)) # By default, this feed is treated as though it came from the # metadata wrangler. No Work has been created. @@ -873,7 +874,7 @@ def test_import(self, opds_importer_fixture: OPDSImporterFixture): # Three links have been added to the identifier of the 'mouse' # edition. image, thumbnail, description = sorted( - mouse.primary_identifier.links, key=lambda x: x.rel + mouse.primary_identifier.links, key=lambda x: str(x.rel) ) # A Representation was imported for the summary with known @@ -896,22 +897,24 @@ def test_import(self, opds_importer_fixture: OPDSImporterFixture): # Two links were added to the identifier of the 'crow' edition. [broken_image, working_image] = sorted( - crow.primary_identifier.links, key=lambda x: x.resource.url + crow.primary_identifier.links, key=lambda x: str(x.resource.url) ) # Because these images did not have a specified media type or a # distinctive extension, and we have not actually retrieved # the URLs yet, we were not able to determine their media type, # so they have no associated Representation. + assert broken_image.resource.url is not None assert broken_image.resource.url.endswith("/broken-cover-image") + assert working_image.resource.url is not None assert working_image.resource.url.endswith("/working-cover-image") - assert None == broken_image.resource.representation - assert None == working_image.resource.representation + assert broken_image.resource.representation is None + assert working_image.resource.representation is None # Three measurements have been added to the 'mouse' edition. popularity, quality, rating = sorted( (x for x in mouse.primary_identifier.measurements if x.is_most_recent), - key=lambda x: x.quantity_measured, + key=lambda x: str(x.quantity_measured), ) assert DataSource.METADATA_WRANGLER == popularity.data_source.name @@ -927,7 +930,7 @@ def test_import(self, opds_importer_fixture: OPDSImporterFixture): assert 0.6 == rating.value seven, children, courtship, fantasy, pz, magic, new_york = sorted( - mouse.primary_identifier.classifications, key=lambda x: x.subject.name + mouse.primary_identifier.classifications, key=lambda x: str(x.subject.name) ) pz_s = pz.subject @@ -1556,8 +1559,8 @@ class NoLinks(Mock): "Simulate an OPDS feed that contains no open-access links." open_access_links = [] - # We don't be making any HTTP requests, even simulated ones. - do_get = object() + # We won't be making any HTTP requests, even simulated ones. + do_get = MagicMock() # Here, there are no links at all. importer = NoLinks(session, None, do_get) @@ -1628,7 +1631,7 @@ def _is_open_access_link(self, url, type): result = good_link_importer.assert_importable_content( "feed", "url", max_get_attempts=5 ) - assert "this is a book" == result + assert True == result # The first link didn't work, but the second one did, # so we didn't try the third one. @@ -2055,7 +2058,7 @@ def follow_one_link(self, url): assert ( "some content", feed_url, - ) == monitor.importer.assert_importable_content_called_with + ) == monitor.importer.assert_importable_content_called_with # type: ignore[attr-defined] assert "looks good" == found_content.result def test_hook_methods(self, opds_importer_fixture: OPDSImporterFixture): @@ -2355,7 +2358,7 @@ def import_one_feed(self, feed): monitor.queue_response([["second next link"], "second page"]) monitor.queue_response([["next link"], "first page"]) - progress = monitor.run_once(object()) + progress = monitor.run_once(MagicMock()) # Feeds are imported in reverse order assert ["last page", "second page", "first page"] == monitor.imports