From 2202d4c06024a2b56dfc225bd930ef8f6f61c1bb Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 25 Oct 2023 14:27:46 -0300 Subject: [PATCH 1/6] Type check collection.py --- core/metadata_layer.py | 2 +- core/model/classification.py | 6 +- core/model/collection.py | 210 +++++++++++++++--------- core/model/contributor.py | 4 +- core/model/hassessioncache.py | 24 +-- core/model/integration.py | 7 +- core/model/library.py | 5 +- core/opds_import.py | 2 +- pyproject.toml | 1 + tests/api/mockapi/overdrive.py | 1 + tests/core/models/test_collection.py | 21 ++- tests/core/models/test_configuration.py | 1 + 12 files changed, 175 insertions(+), 109 deletions(-) diff --git a/core/metadata_layer.py b/core/metadata_layer.py index 1b02c7534a..49a26a8b60 100644 --- a/core/metadata_layer.py +++ b/core/metadata_layer.py @@ -554,7 +554,7 @@ def add_to_pool(self, db: Session, pool: LicensePool): class TimestampData: - CLEAR_VALUE = Timestamp.CLEAR_VALUE + CLEAR_VALUE = Timestamp.CLEAR_VALUE # type: ignore[has-type] def __init__( self, start=None, finish=None, achievements=None, counter=None, exception=None diff --git a/core/model/classification.py b/core/model/classification.py index f225b1484d..d4d406716a 100644 --- a/core/model/classification.py +++ b/core/model/classification.py @@ -56,8 +56,8 @@ class Subject(Base): OVERDRIVE = Classifier.OVERDRIVE # Overdrive's classification system BISAC = Classifier.BISAC BIC = Classifier.BIC # BIC Subject Categories - TAG = Classifier.TAG # Folksonomic tags. - FREEFORM_AUDIENCE = Classifier.FREEFORM_AUDIENCE + TAG: str = Classifier.TAG # Folksonomic tags. + FREEFORM_AUDIENCE: str = Classifier.FREEFORM_AUDIENCE NYPL_APPEAL = Classifier.NYPL_APPEAL # Types with terms that are suitable for search. @@ -65,7 +65,7 @@ class Subject(Base): AXIS_360_AUDIENCE = Classifier.AXIS_360_AUDIENCE GRADE_LEVEL = Classifier.GRADE_LEVEL - AGE_RANGE = Classifier.AGE_RANGE + AGE_RANGE: str = Classifier.AGE_RANGE LEXILE_SCORE = Classifier.LEXILE_SCORE ATOS_SCORE = Classifier.ATOS_SCORE INTEREST_LEVEL = Classifier.INTEREST_LEVEL diff --git a/core/model/collection.py b/core/model/collection.py index 801f3095f7..48ff192b94 100644 --- a/core/model/collection.py +++ b/core/model/collection.py @@ -1,8 +1,8 @@ -# Collection, CollectionIdentifier, CollectionMissing from __future__ import annotations +import datetime from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, TypeVar from sqlalchemy import ( Boolean, @@ -15,11 +15,12 @@ exists, func, ) -from sqlalchemy.orm import Mapped, backref, joinedload, mapper, relationship +from sqlalchemy.orm import Mapped, Query, backref, joinedload, mapper, relationship from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import and_, or_ +from core.external_search import ExternalSearchIndex from core.integration.goals import Goals from core.model import Base, create, get_one, get_one_or_create from core.model.configuration import ConfigurationSetting, ExternalIntegration @@ -42,7 +43,10 @@ if TYPE_CHECKING: # This is needed during type checking so we have the # types of related models. - from core.model import Credential, CustomList, Timestamp # noqa: autoflake + from core.model import Credential, CustomList, Timestamp + + +T = TypeVar("T") class Collection(Base, HasSessionCache): @@ -161,14 +165,16 @@ class Collection(Base, HasSessionCache): # every library. GLOBAL_COLLECTION_DATA_SOURCES = [DataSource.ENKI] - def __repr__(self): - return '' % (self.name, self.protocol, self.id) + def __repr__(self) -> str: + return f'' - def cache_key(self): - return (self.name, self.external_integration.protocol) + def cache_key(self) -> Tuple[str | None, str | None]: + return self.name, self.external_integration.protocol @classmethod - def by_name_and_protocol(cls, _db, name, protocol): + def by_name_and_protocol( + cls, _db: Session, name: str, protocol: str + ) -> Tuple[Collection, bool]: """Find or create a Collection with the given name and the given protocol. @@ -178,13 +184,15 @@ def by_name_and_protocol(cls, _db, name, protocol): """ key = (name, protocol) - def lookup_hook(): + def lookup_hook() -> Tuple[Collection, bool]: return cls._by_name_and_protocol(_db, key) return cls.by_cache_key(_db, key, lookup_hook) @classmethod - def _by_name_and_protocol(cls, _db, cache_key): + def _by_name_and_protocol( + cls, _db: Session, cache_key: Tuple[str, str] + ) -> Tuple[Collection, bool]: """Find or create a Collection with the given name and the given protocol. @@ -215,7 +223,7 @@ def _by_name_and_protocol(cls, _db, cache_key): return collection, is_new @classmethod - def by_protocol(cls, _db, protocol): + def by_protocol(cls, _db: Session, protocol: str | None) -> Query[Collection]: """Query collections that get their licenses through the given protocol. Collections marked for deletion are not included. @@ -239,13 +247,16 @@ def by_protocol(cls, _db, protocol): return qu @classmethod - def by_datasource(cls, _db, data_source): + def by_datasource( + cls, _db: Session, data_source: DataSource | str + ) -> Query[Collection]: """Query collections that are associated with the given DataSource. Collections marked for deletion are not included. """ - if isinstance(data_source, DataSource): - data_source = data_source.name + data_source_name = ( + data_source.name if isinstance(data_source, DataSource) else data_source + ) qu = ( _db.query(cls) @@ -257,23 +268,29 @@ def by_datasource(cls, _db, data_source): IntegrationConfiguration.settings_dict[ Collection.DATA_SOURCE_NAME_SETTING ].astext - == data_source + == data_source_name ) .filter(Collection.marked_for_deletion == False) ) return qu @hybrid_property - def protocol(self): + def protocol(self) -> str: """What protocol do we need to use to get licenses for this collection? """ - return ( - self.integration_configuration and self.integration_configuration.protocol - ) + if self.integration_configuration is None: + raise ValueError("Collection has no integration configuration.") + + if self.integration_configuration.protocol is None: + raise ValueError( + "Collection has integration configuration but no protocol." + ) + + return self.integration_configuration.protocol @protocol.setter - def protocol(self, new_protocol): + def protocol(self, new_protocol: str) -> None: """Modify the protocol in use by this Collection.""" if self.parent and self.parent.protocol != new_protocol: raise ValueError( @@ -285,14 +302,14 @@ def protocol(self, new_protocol): child.protocol = new_protocol @hybrid_property - def primary_identifier_source(self): + def primary_identifier_source(self) -> str | None: """Identify if should try to use another identifier than """ return self.integration_configuration.settings_dict.get( ExternalIntegration.PRIMARY_IDENTIFIER_SOURCE ) @primary_identifier_source.setter - def primary_identifier_source(self, new_primary_identifier_source): + def primary_identifier_source(self, new_primary_identifier_source: str) -> None: """Modify the primary identifier source in use by this Collection.""" self.integration_configuration.settings_dict = ( self.integration_configuration.settings_dict.copy() @@ -311,7 +328,9 @@ def primary_identifier_source(self, new_primary_identifier_source): EBOOK_LOAN_DURATION_KEY = "ebook_loan_duration" STANDARD_DEFAULT_LOAN_PERIOD = 21 - def default_loan_period(self, library, medium=EditionConstants.BOOK_MEDIUM): + def default_loan_period( + self, library: Library, medium: str = EditionConstants.BOOK_MEDIUM + ) -> int: """Until we hear otherwise from the license provider, we assume that someone who borrows a non-open-access item from this collection has it for this number of days. @@ -323,7 +342,7 @@ def default_loan_period(self, library, medium=EditionConstants.BOOK_MEDIUM): return value @classmethod - def loan_period_key(cls, medium=EditionConstants.BOOK_MEDIUM): + def loan_period_key(cls, medium: str = EditionConstants.BOOK_MEDIUM) -> str: if medium == EditionConstants.AUDIO_MEDIUM: return cls.AUDIOBOOK_LOAN_DURATION_KEY else: @@ -331,29 +350,33 @@ def loan_period_key(cls, medium=EditionConstants.BOOK_MEDIUM): def default_loan_period_setting( self, - library, - medium=EditionConstants.BOOK_MEDIUM, - ): + library: Library, + medium: str = EditionConstants.BOOK_MEDIUM, + ) -> Optional[int]: """Until we hear otherwise from the license provider, we assume that someone who borrows a non-open-access item from this collection has it for this number of days. """ key = self.loan_period_key(medium) + if library.id is None: + return None + config = self.integration_configuration.for_library(library.id) + if config is None: + return None - if config: - return config.settings_dict.get(key) + return config.settings_dict.get(key) DEFAULT_RESERVATION_PERIOD_KEY = "default_reservation_period" STANDARD_DEFAULT_RESERVATION_PERIOD = 3 - def _set_settings(self, **kwargs): + def _set_settings(self, **kwargs: Any) -> None: settings_dict = self.integration_configuration.settings_dict.copy() settings_dict.update(kwargs) self.integration_configuration.settings_dict = settings_dict @hybrid_property - def default_reservation_period(self): + def default_reservation_period(self) -> int: """Until we hear otherwise from the license provider, we assume that someone who puts an item on hold has this many days to check it out before it goes to the next person in line. @@ -366,7 +389,7 @@ def default_reservation_period(self): ) @default_reservation_period.setter - def default_reservation_period(self, new_value): + def default_reservation_period(self, new_value: int) -> None: new_value = int(new_value) self._set_settings(**{self.DEFAULT_RESERVATION_PERIOD_KEY: new_value}) @@ -395,7 +418,7 @@ def default_audience(self, new_value: str) -> None: """ self._set_settings(**{self.DEFAULT_AUDIENCE_KEY: str(new_value)}) - def create_external_integration(self, protocol): + def create_external_integration(self, protocol: str) -> ExternalIntegration: """Create an ExternalIntegration for this Collection. To be used immediately after creating a new Collection, @@ -424,7 +447,9 @@ def create_external_integration(self, protocol): self.external_integration_id = external_integration.id return external_integration - def create_integration_configuration(self, protocol): + def create_integration_configuration( + self, protocol: str + ) -> IntegrationConfiguration: _db = Session.object_session(self) goal = Goals.LICENSE_GOAL if self.integration_configuration_id: @@ -466,8 +491,9 @@ def external_integration(self) -> ExternalIntegration: return self._external_integration @property - def unique_account_id(self): + def unique_account_id(self) -> str: """Identifier that uniquely represents this Collection of works""" + unique_account_id: str | None if ( self.data_source and self.data_source.name in self.GLOBAL_COLLECTION_DATA_SOURCES @@ -488,7 +514,7 @@ def unique_account_id(self): return unique_account_id @hybrid_property - def data_source(self): + def data_source(self) -> DataSource | None: """Find the data source associated with this Collection. Bibliographic metadata obtained through the collection @@ -502,7 +528,11 @@ def data_source(self): the data source is a Collection-specific setting. """ data_source = None - name = ExternalIntegration.DATA_SOURCE_FOR_LICENSE_PROTOCOL.get(self.protocol) + name = None + if self.protocol is not None: + name = ExternalIntegration.DATA_SOURCE_FOR_LICENSE_PROTOCOL.get( + self.protocol + ) if not name: name = self.integration_configuration.settings_dict.get( Collection.DATA_SOURCE_NAME_SETTING @@ -513,29 +543,38 @@ def data_source(self): return data_source @data_source.setter - def data_source(self, new_value): - if isinstance(new_value, DataSource): - new_value = new_value.name - if self.protocol == new_value: + def data_source(self, new_value: DataSource | str) -> None: + new_datasource_name = ( + new_value.name if isinstance(new_value, DataSource) else new_value + ) + + if self.protocol == new_datasource_name: return # Only set a DataSource for Collections that don't have an # implied source. if self.protocol not in ExternalIntegration.DATA_SOURCE_FOR_LICENSE_PROTOCOL: - if new_value is not None: - new_value = str(new_value) - self._set_settings(**{Collection.DATA_SOURCE_NAME_SETTING: new_value}) + if new_datasource_name is not None: + new_datasource_name = str(new_datasource_name) + self._set_settings( + **{Collection.DATA_SOURCE_NAME_SETTING: new_datasource_name} + ) @property - def parents(self): - if self.parent_id: - _db = Session.object_session(self) - parent = Collection.by_id(_db, self.parent_id) - yield parent - yield from parent.parents + def parents(self) -> Generator[Collection, None, None]: + if not self.parent_id: + return None + + _db = Session.object_session(self) + parent = Collection.by_id(_db, self.parent_id) + if parent is None: + return None + + yield parent + yield from parent.parents @property - def metadata_identifier(self): + def metadata_identifier(self) -> str: """Identifier based on collection details that uniquely represents this Collection on the metadata wrangler. This identifier is composed of the Collection protocol and account identifier. @@ -558,13 +597,13 @@ def metadata_identifier(self): protocol = encode(self.protocol) metadata_identifier = protocol + ":" + account_id - return encode(metadata_identifier) + return encode(metadata_identifier) # type: ignore[no-any-return] - def disassociate_library(self, library): + def disassociate_library(self, library: Library) -> None: """Disassociate a Library from this Collection and delete any relevant ConfigurationSettings. """ - if library is None or not library in self.libraries: + if library is None or library not in self.libraries: # No-op. return @@ -602,7 +641,7 @@ def disassociate_library(self, library): self.libraries.remove(library) @classmethod - def _decode_metadata_identifier(cls, metadata_identifier): + def _decode_metadata_identifier(cls, metadata_identifier: str) -> Tuple[str, str]: """Invert the metadata_identifier property.""" if not metadata_identifier: raise ValueError("No metadata identifier provided.") @@ -619,7 +658,12 @@ def _decode_metadata_identifier(cls, metadata_identifier): return protocol, account_id @classmethod - def from_metadata_identifier(cls, _db, metadata_identifier, data_source=None): + def from_metadata_identifier( + cls, + _db: Session, + metadata_identifier: str, + data_source: DataSource | str | None = None, + ) -> Tuple[Collection, bool]: """Finds or creates a Collection on the metadata wrangler, based on its unique metadata_identifier. """ @@ -639,22 +683,26 @@ def from_metadata_identifier(cls, _db, metadata_identifier, data_source=None): # identifier. Give it an ExternalIntegration with the # corresponding protocol, and set its data source and # external_account_id. - collection, is_new = create(_db, Collection, name=metadata_identifier) - collection.create_external_integration(protocol) - collection.create_integration_configuration(protocol) + new_collection, is_new = create(_db, Collection, name=metadata_identifier) + new_collection.create_external_integration(protocol) + new_collection.create_integration_configuration(protocol) + collection = new_collection if protocol == ExternalIntegration.OPDS_IMPORT: # For OPDS Import collections only, we store the URL to # the OPDS feed (the "account ID") and the data source. collection.external_account_id = account_id - if data_source and not isinstance(data_source, DataSource): - data_source = DataSource.lookup(_db, data_source, autocreate=True) - collection.data_source = data_source + if isinstance(data_source, DataSource): + collection.data_source = data_source + elif data_source is not None: + collection.data_source = DataSource.lookup( + _db, data_source, autocreate=True + ) return collection, is_new @property - def pools_with_no_delivery_mechanisms(self): + def pools_with_no_delivery_mechanisms(self) -> Query[LicensePool]: """Find all LicensePools in this Collection that have no delivery mechanisms whatsoever. @@ -662,9 +710,9 @@ def pools_with_no_delivery_mechanisms(self): """ _db = Session.object_session(self) qu = LicensePool.with_no_delivery_mechanisms(_db) - return qu.filter(LicensePool.collection == self) + return qu.filter(LicensePool.collection == self) # type: ignore[no-any-return] - def explain(self, include_secrets=False): + def explain(self, include_secrets: bool = False) -> List[str]: """Create a series of human-readable strings to explain a collection's settings. @@ -693,11 +741,11 @@ def explain(self, include_secrets=False): lines.append(f'Setting "{name}": "{value}"') return lines - def catalog_identifier(self, identifier): + def catalog_identifier(self, identifier: Identifier) -> None: """Inserts an identifier into a catalog""" self.catalog_identifiers([identifier]) - def catalog_identifiers(self, identifiers): + def catalog_identifiers(self, identifiers: List[Identifier]) -> None: """Inserts identifiers into the catalog""" if not identifiers: # Nothing to do. @@ -707,7 +755,7 @@ def catalog_identifiers(self, identifiers): already_in_catalog = ( _db.query(Identifier) .join(CollectionIdentifier) - .filter(CollectionIdentifier.collection_id == self.id) + .filter(CollectionIdentifier.collection_id == self.id) # type: ignore[attr-defined] .filter(Identifier.id.in_([x.id for x in identifiers])) .all() ) @@ -720,7 +768,9 @@ def catalog_identifiers(self, identifiers): _db.bulk_insert_mappings(CollectionIdentifier, new_catalog_entries) _db.commit() - def unresolved_catalog(self, _db, data_source_name, operation): + def unresolved_catalog( + self, _db: Session, data_source_name: str, operation: str + ) -> Query[Identifier]: """Returns a query with all identifiers in a Collection's catalog that have unsuccessfully attempted resolution. This method is used on the metadata wrangler. @@ -739,14 +789,16 @@ def unresolved_catalog(self, _db, data_source_name, operation): .outerjoin(Identifier.licensed_through) .outerjoin(Identifier.coverage_records) .outerjoin(LicensePool.work) - .outerjoin(Identifier.collections) + .outerjoin(Identifier.collections) # type: ignore[attr-defined] .filter(Collection.id == self.id, is_not_resolved, Work.id == None) .order_by(Identifier.id) ) return query - def isbns_updated_since(self, _db, timestamp): + def isbns_updated_since( + self, _db: Session, timestamp: datetime.datetime | None + ) -> Query[Identifier]: """Finds all ISBNs in a collection's catalog that have been updated since the timestamp but don't have a Work to show for it. Used in the metadata wrangler. @@ -755,7 +807,7 @@ def isbns_updated_since(self, _db, timestamp): """ isbns = ( _db.query(Identifier, func.max(CoverageRecord.timestamp).label("latest")) - .join(Identifier.collections) + .join(Identifier.collections) # type: ignore[attr-defined] .join(Identifier.coverage_records) .outerjoin(Identifier.licensed_through) .group_by(Identifier.id) @@ -777,11 +829,11 @@ def isbns_updated_since(self, _db, timestamp): @classmethod def restrict_to_ready_deliverable_works( cls, - query, - collection_ids=None, - show_suppressed=False, - allow_holds=True, - ): + query: Query[T], + collection_ids: List[int] | None = None, + show_suppressed: bool = False, + allow_holds: bool = True, + ) -> Query[T]: """Restrict a query to show only presentation-ready works present in an appropriate collection which the default client can fulfill. @@ -856,7 +908,7 @@ def restrict_to_ready_deliverable_works( ) return query - def delete(self, search_index=None): + def delete(self, search_index: ExternalSearchIndex | None = None) -> None: """Delete a collection. Collections can have hundreds of thousands of diff --git a/core/model/contributor.py b/core/model/contributor.py index 284d8e55b1..00e2ebf93a 100644 --- a/core/model/contributor.py +++ b/core/model/contributor.py @@ -3,7 +3,7 @@ import logging import re -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Set from sqlalchemy import Column, ForeignKey, Integer, Unicode, UniqueConstraint from sqlalchemy.dialects.postgresql import ARRAY, JSON @@ -93,7 +93,7 @@ class Contributor(Base): COPYRIGHT_HOLDER_ROLE = "Copyright holder" TRANSCRIBER_ROLE = "Transcriber" DESIGNER_ROLE = "Designer" - AUTHOR_ROLES = {PRIMARY_AUTHOR_ROLE, AUTHOR_ROLE} + AUTHOR_ROLES: Set[str] = {PRIMARY_AUTHOR_ROLE, AUTHOR_ROLE} # Map our recognized roles to MARC relators. # https://www.loc.gov/marc/relators/relaterm.html diff --git a/core/model/hassessioncache.py b/core/model/hassessioncache.py index 1f17ef9bd5..f17fa48bfa 100644 --- a/core/model/hassessioncache.py +++ b/core/model/hassessioncache.py @@ -6,7 +6,7 @@ from abc import abstractmethod from collections import namedtuple from types import SimpleNamespace -from typing import Callable, Hashable +from typing import Callable, Hashable, TypeVar from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Mapped, Session @@ -21,6 +21,8 @@ CacheTuple = namedtuple("CacheTuple", ["id", "key", "stats"]) +T = TypeVar("T") + class HasSessionCache: CACHE_ATTRIBUTE = "_palace_cache" @@ -94,8 +96,8 @@ def _cache_lookup( cache: CacheTuple, cache_name: str, cache_key: Hashable, - cache_miss_hook: Callable[[], tuple[Self | None, bool]], - ) -> tuple[Self | None, bool]: + cache_miss_hook: Callable[[], tuple[T, bool]], + ) -> tuple[T, bool]: """Helper method used by both by_id and by_cache_key. Looks up `cache_key` in the `cache_name` property of `cache`, returning @@ -142,20 +144,22 @@ def _cache_from_session(cls, _db: Session) -> CacheTuple: def by_id(cls, db: Session, id: int) -> Self | None: """Look up an item by its unique database ID.""" cache = cls._cache_from_session(db) - - def lookup_hook(): # type: ignore[no-untyped-def] - return get_one(db, cls, id=id), False - - obj, _ = cls._cache_lookup(db, cache, "id", id, lookup_hook) + obj, _ = cls._cache_lookup( + db, cache, "id", id, lambda: cls._by_id_lookup_hook(db, id) + ) return obj + @classmethod + def _by_id_lookup_hook(cls, db: Session, id: int) -> tuple[Self | None, bool]: + return get_one(db, cls, id=id), False + @classmethod def by_cache_key( cls, db: Session, cache_key: Hashable, - cache_miss_hook: Callable[[], tuple[Self | None, bool]], - ) -> tuple[Self | None, bool]: + cache_miss_hook: Callable[[], tuple[T, bool]], + ) -> tuple[T, bool]: """Look up an item by its cache key.""" cache = cls._cache_from_session(db) return cls._cache_lookup(db, cache, "key", cache_key, cache_miss_hook) diff --git a/core/model/integration.py b/core/model/integration.py index eaa4e43f6a..ebac448cb6 100644 --- a/core/model/integration.py +++ b/core/model/integration.py @@ -72,14 +72,17 @@ def for_library( @overload def for_library( - self, library_id: int, create: Literal[False] = False + self, library_id: int | None, create: bool = False ) -> IntegrationLibraryConfiguration | None: ... def for_library( - self, library_id: int, create: bool = False + self, library_id: int | None, create: bool = False ) -> IntegrationLibraryConfiguration | None: """Fetch the library configuration specifically by library_id""" + if library_id is None: + return None + for config in self.library_configurations: if config.library_id == library_id: return config diff --git a/core/model/library.py b/core/model/library.py index 9b9e563f2b..9f66939124 100644 --- a/core/model/library.py +++ b/core/model/library.py @@ -328,7 +328,8 @@ def enabled_facets(self, group_name: str) -> List[str]: if group_name == FacetConstants.COLLECTION_NAME_FACETS_GROUP_NAME: enabled = [] for collection in self.collections: - enabled.append(collection.name) + if collection.name is not None: + enabled.append(collection.name) return enabled return getattr(self.settings, f"facets_enabled_{group_name}") # type: ignore[no-any-return] @@ -386,7 +387,7 @@ def restrict_to_ready_deliverable_works( collection_ids = collection_ids or [ x.id for x in self.all_collections if x.id is not None ] - return Collection.restrict_to_ready_deliverable_works( # type: ignore[no-any-return] + return Collection.restrict_to_ready_deliverable_works( query, collection_ids=collection_ids, show_suppressed=show_suppressed, diff --git a/core/opds_import.py b/core/opds_import.py index 49dd713fe5..b85098398a 100644 --- a/core/opds_import.py +++ b/core/opds_import.py @@ -1880,7 +1880,7 @@ def data_source(self, collection: Collection) -> Optional[DataSource]: By default, this URL is stored as a setting on the collection, but subclasses may hard-code it. """ - return collection.data_source # type: ignore[no-any-return] + return collection.data_source def feed_contains_new_data(self, feed: bytes | str) -> bool: """Does the given feed contain any entries that haven't been imported diff --git a/pyproject.toml b/pyproject.toml index acb3dc7b98..02490c5d11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ module = [ "core.feed.*", "core.integration.*", "core.model.announcements", + "core.model.collection", "core.model.hassessioncache", "core.model.integration", "core.model.library", diff --git a/tests/api/mockapi/overdrive.py b/tests/api/mockapi/overdrive.py index d7ea7d7f0f..e96c927eeb 100644 --- a/tests/api/mockapi/overdrive.py +++ b/tests/api/mockapi/overdrive.py @@ -70,6 +70,7 @@ def mock_collection( } library.collections.append(collection) db = DatabaseTransactionFixture + assert library.id is not None db.set_settings(config.for_library(library.id, create=True), ils_name=ils_name) _db.refresh(config) return collection diff --git a/tests/core/models/test_collection.py b/tests/core/models/test_collection.py index 1986abf796..4e529e74ea 100644 --- a/tests/core/models/test_collection.py +++ b/tests/core/models/test_collection.py @@ -269,23 +269,24 @@ def test_data_source(self, example_collection_fixture: ExampleCollectionFixture) bibliotheca = db.collection(protocol=ExternalIntegration.BIBLIOTHECA) # The rote data_source is returned for the obvious collection. + assert bibliotheca.data_source is not None assert DataSource.BIBLIOTHECA == bibliotheca.data_source.name # The less obvious OPDS collection doesn't have a DataSource. assert None == opds.data_source # Trying to change the Bibliotheca collection's data_source does nothing. - bibliotheca.data_source = DataSource.AXIS_360 + bibliotheca.data_source = DataSource.AXIS_360 # type: ignore[assignment] assert isinstance(bibliotheca.data_source, DataSource) assert DataSource.BIBLIOTHECA == bibliotheca.data_source.name # Trying to change the opds collection's data_source is fine. - opds.data_source = DataSource.PLYMPTON + opds.data_source = DataSource.PLYMPTON # type: ignore[assignment] assert isinstance(opds.data_source, DataSource) assert DataSource.PLYMPTON == opds.data_source.name # Resetting it to something else is fine. - opds.data_source = DataSource.OA_CONTENT_SERVER + opds.data_source = DataSource.OA_CONTENT_SERVER # type: ignore[assignment] assert isinstance(opds.data_source, DataSource) assert DataSource.OA_CONTENT_SERVER == opds.data_source.name @@ -520,7 +521,8 @@ def new_data_source(): # Because this isn't an OPDS collection, the external account # ID is not stored, the data source is the default source for # the protocol, and no new data source was created. - assert None == mirror_collection.external_account_id + assert mirror_collection.external_account_id is None + assert mirror_collection.data_source is not None assert DataSource.OVERDRIVE == mirror_collection.data_source.name assert None == new_data_source() @@ -529,6 +531,7 @@ def new_data_source(): mirror_collection = create( db.session, Collection, name=collection.metadata_identifier )[0] + assert collection.protocol is not None mirror_collection.create_external_integration(collection.protocol) mirror_collection.create_integration_configuration(collection.protocol) @@ -728,7 +731,7 @@ def assert_isbns(expected, result_query): assert_isbns([i2, i1], updated_isbns) # That CoverageRecord timestamp is also returned. - i1_timestamp = updated_isbns[1][1] + i1_timestamp = updated_isbns[1][1] # type: ignore[index] assert isinstance(i1_timestamp, datetime.datetime) assert i1_oclc_record.timestamp == i1_timestamp @@ -736,8 +739,8 @@ def assert_isbns(expected, result_query): # then will be returned. timestamp = utc_now() i1.coverage_records[0].timestamp = utc_now() - updated_isbns = test_collection.isbns_updated_since(db.session, timestamp) - assert_isbns([i1], updated_isbns) + updated_isbns_2 = test_collection.isbns_updated_since(db.session, timestamp) + assert_isbns([i1], updated_isbns_2) # Prepare an ISBN associated with a Work. work = db.work(with_license_pool=True) @@ -745,8 +748,8 @@ def assert_isbns(expected, result_query): i2.coverage_records[0].timestamp = utc_now() # ISBNs that have a Work will be ignored. - updated_isbns = test_collection.isbns_updated_since(db.session, timestamp) - assert_isbns([i1], updated_isbns) + updated_isbns_3 = test_collection.isbns_updated_since(db.session, timestamp) + assert_isbns([i1], updated_isbns_3) def test_custom_lists(self, example_collection_fixture: ExampleCollectionFixture): db = example_collection_fixture.database_fixture diff --git a/tests/core/models/test_configuration.py b/tests/core/models/test_configuration.py index 7d10ba373d..4ac3f35691 100644 --- a/tests/core/models/test_configuration.py +++ b/tests/core/models/test_configuration.py @@ -515,6 +515,7 @@ def test_data_source( # For most collections, the protocol determines the # data source. collection = db.collection(protocol=ExternalIntegration.OVERDRIVE) + assert collection.data_source is not None assert DataSource.OVERDRIVE == collection.data_source.name # For OPDS Import collections, data source is a setting which From f6e07c2fc5cff475e6ea59440a7737efef38029a Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 25 Oct 2023 14:32:32 -0300 Subject: [PATCH 2/6] Fix import issue --- core/model/collection.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/model/collection.py b/core/model/collection.py index 48ff192b94..47053908fe 100644 --- a/core/model/collection.py +++ b/core/model/collection.py @@ -20,7 +20,6 @@ from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import and_, or_ -from core.external_search import ExternalSearchIndex from core.integration.goals import Goals from core.model import Base, create, get_one, get_one_or_create from core.model.configuration import ConfigurationSetting, ExternalIntegration @@ -41,8 +40,7 @@ from core.util.string_helpers import base64 if TYPE_CHECKING: - # This is needed during type checking so we have the - # types of related models. + from core.external_search import ExternalSearchIndex from core.model import Credential, CustomList, Timestamp From d501e5a01113eb9e58863830da6e283dbfb410d0 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 25 Oct 2023 14:35:30 -0300 Subject: [PATCH 3/6] Remove unused ignore --- core/metadata_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/metadata_layer.py b/core/metadata_layer.py index 49a26a8b60..1b02c7534a 100644 --- a/core/metadata_layer.py +++ b/core/metadata_layer.py @@ -554,7 +554,7 @@ def add_to_pool(self, db: Session, pool: LicensePool): class TimestampData: - CLEAR_VALUE = Timestamp.CLEAR_VALUE # type: ignore[has-type] + CLEAR_VALUE = Timestamp.CLEAR_VALUE def __init__( self, start=None, finish=None, achievements=None, counter=None, exception=None From 84350bb00e68e11b54ce3006713520d375c5321f Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 25 Oct 2023 16:18:24 -0300 Subject: [PATCH 4/6] Make sure multitest has integration configuration. --- tests/api/test_controller_multilib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/api/test_controller_multilib.py b/tests/api/test_controller_multilib.py index da69c34d9f..4e725fce1c 100644 --- a/tests/api/test_controller_multilib.py +++ b/tests/api/test_controller_multilib.py @@ -21,6 +21,7 @@ def make_default_collection(_db, library): name=f"{controller_fixture.db.fresh_str()} (for multi-library test)", ) collection.create_external_integration(ExternalIntegration.OPDS_IMPORT) + collection.create_integration_configuration(ExternalIntegration.OPDS_IMPORT) library.collections.append(collection) return collection From 28557ebbc96bd68819fb68c6fb3594802ab1e8e7 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 25 Oct 2023 16:52:46 -0300 Subject: [PATCH 5/6] Add another test --- tests/api/test_controller_scopedsession.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/api/test_controller_scopedsession.py b/tests/api/test_controller_scopedsession.py index 53afdaa5cd..797a7b25f6 100644 --- a/tests/api/test_controller_scopedsession.py +++ b/tests/api/test_controller_scopedsession.py @@ -60,6 +60,7 @@ def make_default_collection(self, session: Session, library): name=self.fresh_id() + " (collection for scoped session)", ) collection.create_external_integration(ExternalIntegration.OPDS_IMPORT) + collection.create_integration_configuration(ExternalIntegration.OPDS_IMPORT) library.collections.append(collection) return collection From c7a515471a5419906d9d88c6e0fdf417c8d53deb Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 25 Oct 2023 17:34:56 -0300 Subject: [PATCH 6/6] Add some more tests --- tests/core/models/test_collection.py | 26 +++++++++++++++++++ .../models/test_integration_configuration.py | 5 +++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/core/models/test_collection.py b/tests/core/models/test_collection.py index 4e529e74ea..cde4aeeadf 100644 --- a/tests/core/models/test_collection.py +++ b/tests/core/models/test_collection.py @@ -229,6 +229,32 @@ def test_unique_account_id( enki_child.parent = enki assert DataSource.ENKI + "+enkichild" == enki_child.unique_account_id + def test_get_protocol(self, db: DatabaseTransactionFixture): + test_collection = db.collection() + integration = test_collection.integration_configuration + test_collection.integration_configuration = None + + # A collection with no associated ExternalIntegration has no protocol. + with pytest.raises(ValueError) as excinfo: + getattr(test_collection, "protocol") + + assert "Collection has no integration configuration" in str(excinfo.value) + + integration.protocol = None + test_collection.integration_configuration = integration + + # If a collection has an integration that doesn't have a protocol set, + # it has no protocol, so we get an exception. + with pytest.raises(ValueError) as excinfo: + getattr(test_collection, "protocol") + + assert "Collection has integration configuration but no protocol" in str( + excinfo.value + ) + + integration.protocol = "test protocol" + assert test_collection.protocol == "test protocol" + def test_change_protocol( self, example_collection_fixture: ExampleCollectionFixture ): diff --git a/tests/core/models/test_integration_configuration.py b/tests/core/models/test_integration_configuration.py index 50c3c5f4ed..251487423e 100644 --- a/tests/core/models/test_integration_configuration.py +++ b/tests/core/models/test_integration_configuration.py @@ -16,8 +16,11 @@ def test_for_library(seslf, db: DatabaseTransactionFixture): library = db.default_library() assert library.id is not None + # No library ID provided + assert config.for_library(None) is None + # No library config exists - assert config.for_library(library.id) == None + assert config.for_library(library.id) is None # This should create a new config libconfig = config.for_library(library.id, create=True)