From 745b4b55a733c64ea7c5f930d7719f8f0a581ab3 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Fri, 22 Sep 2023 09:01:32 -0300 Subject: [PATCH] Type hint ODL importer classes --- api/circulation.py | 2 +- api/lcp/hash.py | 13 ++- api/odl.py | 214 +++++++++++++++++++++++++--------------- api/odl2.py | 69 +++++++------ core/metadata_layer.py | 2 +- core/model/licensing.py | 21 ++-- core/model/patron.py | 3 +- pyproject.toml | 3 + 8 files changed, 196 insertions(+), 131 deletions(-) diff --git a/api/circulation.py b/api/circulation.py index 9765b0801e..0a0392f258 100644 --- a/api/circulation.py +++ b/api/circulation.py @@ -436,7 +436,7 @@ def __init__( identifier_type: Optional[str], identifier: Optional[str], start_date: Optional[datetime.datetime], - end_date: datetime.datetime, + end_date: Optional[datetime.datetime], fulfillment_info: Optional[FulfillmentInfo] = None, external_identifier: Optional[str] = None, locked_to: Optional[DeliveryMechanismInfo] = None, diff --git a/api/lcp/hash.py b/api/lcp/hash.py index 4569dfa58d..5fd98a75c7 100644 --- a/api/lcp/hash.py +++ b/api/lcp/hash.py @@ -1,5 +1,5 @@ import hashlib -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from enum import Enum from core.exceptions import BaseError @@ -14,20 +14,19 @@ class HashingError(BaseError): """Raised in the case of errors occurred during hashing""" -class Hasher(metaclass=ABCMeta): +class Hasher(ABC): """Base class for all implementations of different hashing algorithms""" - def __init__(self, hashing_algorithm): + def __init__(self, hashing_algorithm: HashingAlgorithm) -> None: """Initializes a new instance of Hasher class :param hashing_algorithm: Hashing algorithm - :type hashing_algorithm: HashingAlgorithm """ self._hashing_algorithm = hashing_algorithm @abstractmethod - def hash(self, value): - raise NotImplementedError() + def hash(self, value: str) -> str: + ... class UniversalHasher(Hasher): @@ -49,5 +48,5 @@ def hash(self, value: str) -> str: class HasherFactory: - def create(self, hashing_algorithm): + def create(self, hashing_algorithm: HashingAlgorithm) -> Hasher: return UniversalHasher(hashing_algorithm) diff --git a/api/odl.py b/api/odl.py index 02e00010a1..5d42046bc4 100644 --- a/api/odl.py +++ b/api/odl.py @@ -1,15 +1,18 @@ +from __future__ import annotations + import binascii import datetime import json import logging import uuid -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import dateutil -import sqlalchemy from flask import url_for from flask_babel import lazy_gettext as _ +from lxml.etree import Element from pydantic import HttpUrl, PositiveInt +from requests import Response from sqlalchemy.sql.expression import or_ from uritemplate import URITemplate @@ -218,20 +221,22 @@ class ODLAPI( ] @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[ODLSettings]: return ODLSettings @classmethod - def library_settings_class(cls): + def library_settings_class(cls) -> Type[ODLLibrarySettings]: return ODLLibrarySettings - def label(self): - return self.NAME + @classmethod + def label(cls) -> str: + return cls.NAME - def description(self): - return self.DESCRIPTION + @classmethod + def description(cls) -> str: + return cls.DESCRIPTION # type: ignore[no-any-return] - def __init__(self, _db, collection): + def __init__(self, _db: Session, collection: Collection) -> None: super().__init__(_db, collection) if collection.protocol != self.NAME: raise ValueError( @@ -252,9 +257,7 @@ def __init__(self, _db, collection): self._credential_factory = LCPCredentialFactory() self._hasher_instance: Optional[Hasher] = None - def external_integration( - self, db: sqlalchemy.orm.session.Session - ) -> ExternalIntegration: + def external_integration(self, db: Session) -> ExternalIntegration: """Return an external integration associated with this object. :param db: Database session @@ -262,7 +265,9 @@ def external_integration( """ return self.collection.external_integration - def internal_format(self, delivery_mechanism): + def internal_format( # type: ignore[override] + self, delivery_mechanism: Optional[LicensePoolDeliveryMechanism] + ) -> Optional[LicensePoolDeliveryMechanism]: """Each consolidated copy is only available in one format, so we don't need a mapping to internal formats. """ @@ -280,23 +285,22 @@ def collection(self) -> Collection: raise ValueError(f"Collection not found: {self.collection_id}") return collection - def _get_hasher(self): + def _get_hasher(self) -> Hasher: """Returns a Hasher instance :return: Hasher instance - :rtype: hash.Hasher """ config = self.configuration() if self._hasher_instance is None: self._hasher_instance = self._hasher_factory.create( - config.encryption_algorithm + config.encryption_algorithm # type: ignore[arg-type] if config.encryption_algorithm else ODLAPIConstants.DEFAULT_ENCRYPTION_ALGORITHM ) return self._hasher_instance - def _get(self, url, headers=None): + def _get(self, url: str, headers: Optional[Dict[str, str]] = None) -> Response: """Make a normal HTTP request, but include an authentication header with the credentials for the collection. """ @@ -309,11 +313,11 @@ def _get(self, url, headers=None): return HTTP.get_with_timeout(url, headers=headers) - def _url_for(self, *args, **kwargs): + def _url_for(self, *args: Any, **kwargs: Any) -> str: """Wrapper around flask's url_for to be overridden for tests.""" return url_for(*args, **kwargs) - def get_license_status_document(self, loan): + def get_license_status_document(self, loan: Loan) -> Dict[str, Any]: """Get the License Status Document for a loan. For a new loan, create a local loan with no external identifier and @@ -360,9 +364,10 @@ def get_license_status_document(self, loan): ) config = self.configuration() - url_template = URITemplate(loan.license.checkout_url) + checkout_url = str(loan.license.checkout_url) + url_template = URITemplate(checkout_url) url = url_template.expand( - id=id, + id=str(id), checkout_id=checkout_id, patron_id=patron_id, expires=expires.isoformat(), @@ -384,9 +389,9 @@ def get_license_status_document(self, loan): raise BadResponseException( url, "License Status Document had an unknown status value." ) - return status_doc + return status_doc # type: ignore[no-any-return] - def checkin(self, patron, pin, licensepool): + def checkin(self, patron: Patron, pin: str, licensepool: LicensePool) -> bool: # type: ignore[override] """Return a loan early.""" _db = Session.object_session(patron) @@ -397,10 +402,10 @@ def checkin(self, patron, pin, licensepool): ) if loan.count() < 1: raise NotCheckedOut() - loan = loan.one() - return self._checkin(loan) + loan_result = loan.one() + return self._checkin(loan_result) - def _checkin(self, loan): + def _checkin(self, loan: Loan) -> bool: _db = Session.object_session(loan) doc = self.get_license_status_document(loan) status = doc.get("status") @@ -427,7 +432,7 @@ def _checkin(self, loan): # must be returned through the DRM system. If that's true, the # app will already be doing that on its own, so we'll silently # do nothing. - return + return False # Hit the distributor's return link. self._get(return_url) @@ -439,12 +444,18 @@ def _checkin(self, loan): # However, it might be because the loan has already been fulfilled # and must be returned through the DRM system, which the app will # do on its own, so we can ignore the problem. - loan = get_one(_db, Loan, id=loan.id) - if loan: - return + new_loan = get_one(_db, Loan, id=loan.id) + if new_loan: + return False return True - def checkout(self, patron, pin, licensepool, internal_format): + def checkout( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + internal_format: Optional[str], + ) -> LoanInfo: """Create a new loan.""" _db = Session.object_session(patron) @@ -457,18 +468,20 @@ def checkout(self, patron, pin, licensepool, internal_format): raise AlreadyCheckedOut() hold = get_one(_db, Hold, patron=patron, license_pool_id=licensepool.id) - loan = self._checkout(patron, licensepool, hold) + loan_obj = self._checkout(patron, licensepool, hold) return LoanInfo( licensepool.collection, licensepool.data_source.name, licensepool.identifier.type, licensepool.identifier.identifier, - loan.start, - loan.end, - external_identifier=loan.external_identifier, + loan_obj.start, + loan_obj.end, + external_identifier=loan_obj.external_identifier, ) - def _checkout(self, patron: Patron, licensepool, hold=None): + def _checkout( + self, patron: Patron, licensepool: LicensePool, hold: Optional[Hold] = None + ) -> Loan: _db = Session.object_session(patron) if not any(l for l in licensepool.licenses if not l.is_inactive): @@ -483,7 +496,9 @@ def _checkout(self, patron: Patron, licensepool, hold=None): # If there's a holds queue, the patron must have a non-expired hold # with position 0 to check out the book. if ( - not hold or hold.position > 0 or (hold.end and hold.end < utc_now()) + not hold + or (hold.position and hold.position > 0) + or (hold.end and hold.end < utc_now()) ) and licensepool.licenses_available < 1: raise NoAvailableCopies() @@ -534,27 +549,28 @@ def _checkout(self, patron: Patron, licensepool, hold=None): self.update_licensepool(licensepool) return loan - def fulfill(self, patron, pin, licensepool, internal_format, **kwargs): - """Get the actual resource file to the patron. - - :param kwargs: A container for arguments to fulfill() - which are not relevant to this vendor. - - :return: a FulfillmentInfo object. - """ + def fulfill( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + internal_format: Optional[str] = None, + part: Optional[str] = None, + fulfill_part_url: Optional[Callable[[Optional[str]], str]] = None, + ) -> FulfillmentInfo: + """Get the actual resource file to the patron.""" _db = Session.object_session(patron) loan = ( _db.query(Loan) .filter(Loan.patron == patron) .filter(Loan.license_pool_id == licensepool.id) - ) - loan = loan.one() + ).one() return self._fulfill(loan, internal_format) @staticmethod def _find_content_link_and_type( - links: List[Dict], + links: List[Dict[str, str]], drm_scheme: Optional[str], ) -> Tuple[Optional[str], Optional[str]]: """Find a content link with the type information corresponding to the selected delivery mechanism. @@ -647,7 +663,7 @@ def _count_holds_before(self, holdinfo: HoldInfo, pool: LicensePool) -> int: .count() ) - def _update_hold_data(self, hold: Hold): + def _update_hold_data(self, hold: Hold) -> None: pool: LicensePool = hold.license_pool holdinfo = HoldInfo( pool.collection, @@ -665,7 +681,7 @@ def _update_hold_data(self, hold: Hold): def _update_hold_end_date( self, holdinfo: HoldInfo, pool: LicensePool, library: Library - ): + ) -> None: _db = Session.object_session(pool) # First make sure the hold position is up-to-date, since we'll @@ -751,7 +767,7 @@ def _update_hold_end_date( days=default_reservation_period ) - def _update_hold_position(self, holdinfo: HoldInfo, pool: LicensePool): + def _update_hold_position(self, holdinfo: HoldInfo, pool: LicensePool) -> None: _db = Session.object_session(pool) loans_count = ( _db.query(Loan) @@ -774,7 +790,7 @@ def _update_hold_position(self, holdinfo: HoldInfo, pool: LicensePool): # Add 1 since position 0 indicates the hold is ready. holdinfo.hold_position = holds_count + 1 - def update_licensepool(self, licensepool: LicensePool): + def update_licensepool(self, licensepool: LicensePool) -> None: # Update the pool and the next holds in the queue when a license is reserved. licensepool.update_availability_from_licenses( analytics=self.analytics, @@ -786,11 +802,17 @@ def update_licensepool(self, licensepool: LicensePool): # This hold just got a reserved license. self._update_hold_data(hold) - def place_hold(self, patron, pin, licensepool, notification_email_address): + def place_hold( + self, + patron: Patron, + pin: str, + licensepool: LicensePool, + notification_email_address: Optional[str], + ) -> HoldInfo: """Create a new hold.""" return self._place_hold(patron, licensepool) - def _place_hold(self, patron, licensepool): + def _place_hold(self, patron: Patron, licensepool: LicensePool) -> HoldInfo: _db = Session.object_session(patron) # Make sure pool info is updated. @@ -810,14 +832,19 @@ def _place_hold(self, patron, licensepool): if hold is not None: raise AlreadyOnHold() - licensepool.patrons_in_hold_queue += 1 + patrons_in_hold_queue = ( + licensepool.patrons_in_hold_queue + if licensepool.patrons_in_hold_queue + else 0 + ) + licensepool.patrons_in_hold_queue = patrons_in_hold_queue + 1 holdinfo = HoldInfo( licensepool.collection, licensepool.data_source.name, licensepool.identifier.type, licensepool.identifier.identifier, utc_now(), - 0, + None, 0, ) library = patron.library @@ -825,7 +852,7 @@ def _place_hold(self, patron, licensepool): return holdinfo - def release_hold(self, patron, pin, licensepool): + def release_hold(self, patron: Patron, pin: str, licensepool: LicensePool) -> None: """Cancel a hold.""" _db = Session.object_session(patron) @@ -837,9 +864,9 @@ def release_hold(self, patron, pin, licensepool): ) if not hold: raise NotOnHold() - return self._release_hold(hold) + self._release_hold(hold) - def _release_hold(self, hold): + def _release_hold(self, hold: Hold) -> Literal[True]: # If the book was ready and the patron revoked the hold instead # of checking it out, but no one else had the book on hold, the # book is now available for anyone to check out. If someone else @@ -852,7 +879,7 @@ def _release_hold(self, hold): self.update_licensepool(licensepool) return True - def patron_activity(self, patron, pin): + def patron_activity(self, patron: Patron, pin: str) -> List[LoanInfo | HoldInfo]: """Look up non-expired loans for this collection in the database.""" _db = Session.object_session(patron) loans = ( @@ -904,7 +931,9 @@ def patron_activity(self, patron, pin): for hold in remaining_holds ] - def update_loan(self, loan, status_doc=None): + def update_loan( + self, loan: Loan, status_doc: Optional[Dict[str, Any]] = None + ) -> None: """Check a loan's status, and if it is no longer active, delete the loan and update its pool's availability. """ @@ -918,7 +947,8 @@ def update_loan(self, loan, status_doc=None): # but if the document came from a notification it hasn't been checked yet. if status not in self.STATUS_VALUES: raise BadResponseException( - "The License Status Document had an unknown status value." + str(loan.license.checkout_url), + "The License Status Document had an unknown status value.", ) if status in [ @@ -937,7 +967,9 @@ def update_loan(self, loan, status_doc=None): _db.delete(loan) self.update_licensepool(loan.license_pool) - def update_availability(self, licensepool): + def update_availability( # type: ignore[empty-body] + self, licensepool: LicensePool + ) -> Tuple[LicensePool, bool, bool]: pass @@ -975,11 +1007,13 @@ class ODLImporter(OPDSImporter): } @classmethod - def fetch_license_info(cls, document_link: str, do_get: Callable) -> Optional[dict]: + def fetch_license_info( + cls, document_link: str, do_get: Callable[..., Tuple[int, Any, bytes]] + ) -> Optional[Dict[str, Any]]: status_code, _, response = do_get(document_link, headers={}) if status_code in (200, 201): license_info_document = json.loads(response) - return license_info_document + return license_info_document # type: ignore[no-any-return] else: logging.warning( f"License Info Document is not available. " @@ -990,9 +1024,9 @@ def fetch_license_info(cls, document_link: str, do_get: Callable) -> Optional[di @classmethod def parse_license_info( cls, - license_info_document: dict, + license_info_document: Dict[str, Any], license_info_link: str, - checkout_link: str, + checkout_link: Optional[str], ) -> Optional[LicenseData]: """Check the license's attributes passed as parameters: - if they're correct, turn them into a LicenseData object @@ -1078,11 +1112,11 @@ def parse_license_info( def get_license_data( cls, license_info_link: str, - checkout_link: str, - feed_license_identifier: str, - feed_license_expires: str, - feed_concurrency: int, - do_get: Callable, + checkout_link: Optional[str], + feed_license_identifier: Optional[str], + feed_license_expires: Optional[datetime.datetime], + feed_concurrency: Optional[int], + do_get: Callable[..., Tuple[int, Any, bytes]], ) -> Optional[LicenseData]: license_info_document = cls.fetch_license_info(license_info_link, do_get) @@ -1130,8 +1164,12 @@ def get_license_data( @classmethod def _detail_for_elementtree_entry( - cls, parser, entry_tag, feed_url=None, do_get=None - ): + cls, + parser: OPDSXMLParser, + entry_tag: Element, + feed_url: Optional[str] = None, + do_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + ) -> Dict[str, Any]: do_get = do_get or Representation.cautious_http_get # TODO: Review for consistency when updated ODL spec is ready. @@ -1152,7 +1190,7 @@ def _detail_for_elementtree_entry( # By default, dcterms:format includes the media type of a # DRM-free resource. content_type = full_content_type - drm_schemes = [] + drm_schemes: List[str | None] = [] # But it may instead describe an audiobook protected with # the Feedbooks access-control scheme. @@ -1207,10 +1245,10 @@ def _detail_for_elementtree_entry( expires = subtag(terms[0], "odl:expires") if concurrent_checkouts is not None: - concurrent_checkouts = int(concurrent_checkouts) + concurrent_checkouts_int = int(concurrent_checkouts) if expires is not None: - expires = to_utc(dateutil.parser.parse(expires)) + expires_datetime = to_utc(dateutil.parser.parse(expires)) if not odl_status_link: parsed_license = None @@ -1219,8 +1257,8 @@ def _detail_for_elementtree_entry( odl_status_link, checkout_link, identifier, - expires, - concurrent_checkouts, + expires_datetime, + concurrent_checkouts_int, do_get, ) @@ -1248,7 +1286,13 @@ class ODLImportMonitor(OPDSImportMonitor): PROTOCOL = ODLImporter.NAME SERVICE_NAME = "ODL Import Monitor" - def __init__(self, _db, collection, import_class, **import_class_kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + import_class: Type[OPDSImporter], + **import_class_kwargs: Any, + ): # Always force reimport ODL collections to get up to date license information super().__init__( _db, collection, import_class, force_reimport=True, **import_class_kwargs @@ -1262,11 +1306,17 @@ class ODLHoldReaper(CollectionMonitor): SERVICE_NAME = "ODL Hold Reaper" PROTOCOL = ODLAPI.NAME - def __init__(self, _db, collection=None, api=None, **kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + api: Optional[ODLAPI] = None, + **kwargs: Any, + ): super().__init__(_db, collection, **kwargs) self.api = api or ODLAPI(_db, collection) - def run_once(self, progress): + def run_once(self, progress: TimestampData) -> TimestampData: # Find holds that have expired. expired_holds = ( self._db.query(Hold) diff --git a/api/odl2.py b/api/odl2.py index b4b8bc0dbd..b0e7808fe4 100644 --- a/api/odl2.py +++ b/api/odl2.py @@ -1,13 +1,17 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type from flask_babel import lazy_gettext as _ from pydantic import PositiveInt +from sqlalchemy.orm import Session +from webpub_manifest_parser.core.ast import Metadata from webpub_manifest_parser.odl import ODLFeedParserFactory +from webpub_manifest_parser.opds2.ast import OPDS2Feed, OPDS2Publication from webpub_manifest_parser.opds2.registry import OPDS2LinkRelationsRegistry +from api.circulation import HoldInfo from api.circulation_exceptions import PatronHoldLimitReached, PatronLoanLimitReached from api.odl import ODLAPI, ODLImporter, ODLSettings from core.integration.settings import ( @@ -16,14 +20,14 @@ FormField, ) from core.metadata_layer import FormatData -from core.model import Edition, RightsStatus +from core.model import Collection, Edition, Identifier, LicensePool, RightsStatus from core.model.configuration import ExternalIntegration, HasExternalIntegration from core.opds2_import import OPDS2Importer, OPDS2ImportMonitor, RWPMManifestParser from core.util import first_or_default from core.util.datetime_helpers import to_utc if TYPE_CHECKING: - from core.model.patron import Patron + from core.model.patron import Hold, Loan, Patron class ODL2Settings(ODLSettings): @@ -71,16 +75,18 @@ class ODL2API(ODLAPI): NAME = ExternalIntegration.ODL2 @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[ODL2Settings]: return ODL2Settings - def __init__(self, _db, collection): + def __init__(self, _db: Session, collection: Collection) -> None: super().__init__(_db, collection) config = self.configuration() - self.loan_limit = config.loan_limit - self.hold_limit = config.hold_limit + self.loan_limit = config.loan_limit # type: ignore[attr-defined] + self.hold_limit = config.hold_limit # type: ignore[attr-defined] - def _checkout(self, patron: Patron, licensepool, hold=None): + def _checkout( + self, patron: Patron, licensepool: LicensePool, hold: Optional[Hold] = None + ) -> Loan: # If the loan limit is not None or 0 if self.loan_limit: loans = list( @@ -93,7 +99,7 @@ def _checkout(self, patron: Patron, licensepool, hold=None): raise PatronLoanLimitReached(limit=self.loan_limit) return super()._checkout(patron, licensepool, hold) - def _place_hold(self, patron: Patron, licensepool): + def _place_hold(self, patron: Patron, licensepool: LicensePool) -> HoldInfo: # If the hold limit is not None or 0 if self.hold_limit: holds = list( @@ -117,19 +123,19 @@ class ODL2Importer(OPDS2Importer, HasExternalIntegration): NAME = ODL2API.NAME @classmethod - def settings_class(cls): + def settings_class(cls) -> Type[ODL2Settings]: # type: ignore[override] return ODL2Settings def __init__( self, - db, - collection, - parser=None, - data_source_name=None, - identifier_mapping=None, - http_get=None, - content_modifier=None, - map_from_collection=None, + db: Session, + collection: Collection, + parser: Optional[RWPMManifestParser] = None, + data_source_name: str | None = None, + identifier_mapping: Dict[Identifier, Identifier] | None = None, + http_get: Optional[Callable[..., Tuple[int, Any, bytes]]] = None, + content_modifier: Optional[Callable[..., None]] = None, + map_from_collection: Optional[bool] = None, ): """Initialize a new instance of ODL2Importer class. @@ -173,20 +179,19 @@ def __init__( ) self._logger = logging.getLogger(__name__) - def _extract_publication_metadata(self, feed, publication, data_source_name): + def _extract_publication_metadata( + self, + feed: OPDS2Feed, + publication: OPDS2Publication, + data_source_name: Optional[str], + ) -> Metadata: """Extract a Metadata object from webpub-manifest-parser's publication. :param publication: Feed object - :type publication: opds2_ast.OPDS2Feed - :param publication: Publication object - :type publication: opds2_ast.OPDS2Publication - :param data_source_name: Data source's name - :type data_source_name: str :return: Publication's metadata - :rtype: Metadata """ metadata = super()._extract_publication_metadata( feed, publication, data_source_name @@ -195,7 +200,7 @@ def _extract_publication_metadata(self, feed, publication, data_source_name): licenses = [] medium = None - skipped_license_formats = self.configuration().skipped_license_formats + skipped_license_formats = self.configuration().skipped_license_formats # type: ignore[attr-defined] if skipped_license_formats: skipped_license_formats = set(skipped_license_formats) @@ -251,6 +256,7 @@ def _extract_publication_metadata(self, feed, publication, data_source_name): if not medium: medium = Edition.medium_from_media_type(license_format) + drm_schemes: List[str | None] if license_format in ODLImporter.LICENSE_FORMATS: # Special case to handle DeMarque audiobooks which include the protection # in the content type. When we see a license format of @@ -291,9 +297,6 @@ def _extract_publication_metadata(self, feed, publication, data_source_name): return metadata - def external_integration(self, db): - return self.collection.external_integration - class ODL2ImportMonitor(OPDS2ImportMonitor): """Import information from an ODL feed.""" @@ -301,7 +304,13 @@ class ODL2ImportMonitor(OPDS2ImportMonitor): PROTOCOL = ODL2Importer.NAME SERVICE_NAME = "ODL 2.x Import Monitor" - def __init__(self, _db, collection, import_class, **import_class_kwargs): + def __init__( + self, + _db: Session, + collection: Collection, + import_class: Type[ODL2Importer], + **import_class_kwargs: Any, + ) -> None: # Always force reimport ODL collections to get up to date license information super().__init__( _db, collection, import_class, force_reimport=True, **import_class_kwargs diff --git a/core/metadata_layer.py b/core/metadata_layer.py index e73ae1dc8b..4fe5e06edf 100644 --- a/core/metadata_layer.py +++ b/core/metadata_layer.py @@ -523,7 +523,7 @@ class LicenseData(LicenseFunctions): def __init__( self, identifier: str, - checkout_url: str, + checkout_url: Optional[str], status_url: str, status: LicenseStatus, checkouts_available: int, diff --git a/core/model/licensing.py b/core/model/licensing.py index 2d139f3f9d..3511d3b8bd 100644 --- a/core/model/licensing.py +++ b/core/model/licensing.py @@ -5,7 +5,7 @@ import datetime import logging from enum import Enum as PythonEnum -from typing import TYPE_CHECKING, List, Literal, Tuple, overload +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, overload from sqlalchemy import Boolean, Column, DateTime from sqlalchemy import Enum as AlchemyEnum @@ -139,12 +139,12 @@ class License(Base, LicenseFunctions): # One License can have many Loans. loans: Mapped[List[Loan]] = relationship( - "Loan", backref="license", cascade="all, delete-orphan" + "Loan", back_populates="license", cascade="all, delete-orphan" ) __table_args__ = (UniqueConstraint("identifier", "license_pool_id"),) - def loan_to(self, patron: Patron, **kwargs): + def loan_to(self, patron: Patron, **kwargs) -> Tuple[Loan, bool]: loan, is_new = self.license_pool.loan_to(patron, **kwargs) loan.license = self return loan, is_new @@ -1021,7 +1021,7 @@ def loan_to( end=None, fulfillment=None, external_identifier=None, - ): + ) -> Tuple[Loan, bool]: _db = Session.object_session(patron) kwargs = dict(start=start or utc_now(), end=end) loan, is_new = get_one_or_create( @@ -1067,7 +1067,7 @@ def on_hold_to( hold.external_identifier = external_identifier return hold, new - def best_available_license(self): + def best_available_license(self) -> License | None: """Determine the next license that should be lent out for this pool. Time-limited licenses and perpetual licenses are the best. It doesn't matter which @@ -1084,7 +1084,7 @@ def best_available_license(self): The worst option would be pay-per-use, but we don't yet support any distributors that offer that model. """ - best = None + best: Optional[License] = None now = utc_now() for license in self.licenses: @@ -1094,7 +1094,10 @@ def best_available_license(self): active_loan_count = len( [l for l in license.loans if not l.end or l.end > now] ) - if active_loan_count >= license.checkouts_available: + checkouts_available = ( + license.checkouts_available if license.checkouts_available else 0 + ) + if active_loan_count >= checkouts_available: continue if ( @@ -1103,13 +1106,13 @@ def best_available_license(self): or ( license.is_time_limited and best.is_time_limited - and license.expires < best.expires + and license.expires < best.expires # type: ignore[operator] ) or (license.is_perpetual and not best.is_time_limited) or ( license.is_loan_limited and best.is_loan_limited - and license.checkouts_left > best.checkouts_left + and license.checkouts_left > best.checkouts_left # type: ignore[operator] ) ): best = license diff --git a/core/model/patron.py b/core/model/patron.py index 1d4e27bcd0..f5e8ccf8dd 100644 --- a/core/model/patron.py +++ b/core/model/patron.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from core.model.library import Library - from core.model.licensing import LicensePool, LicensePoolDeliveryMechanism + from core.model.licensing import License, LicensePool, LicensePoolDeliveryMechanism from .devicetokens import DeviceToken @@ -544,6 +544,7 @@ class Loan(Base, LoanAndHoldMixin): # It may also be associated with an individual License if the source # provides information about individual licenses. license_id = Column(Integer, ForeignKey("licenses.id"), index=True, nullable=True) + license: Mapped[License] = relationship("License", back_populates="loans") fulfillment_id = Column(Integer, ForeignKey("licensepooldeliveries.id")) fulfillment: Mapped[Optional[LicensePoolDeliveryMechanism]] = relationship( diff --git a/pyproject.toml b/pyproject.toml index 4dcf00fa58..fbe081af42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,9 @@ module = [ "api.circulation", "api.discovery.*", "api.integration.*", + "api.lcp.hash", + "api.odl", + "api.odl2", "core.feed.*", "core.integration.*", "core.model.announcements",