diff --git a/alembic/versions/20241029_3faa5bba3ddf_add_index_ix_licensepools_collection_id_.py b/alembic/versions/20241029_3faa5bba3ddf_add_index_ix_licensepools_collection_id_.py new file mode 100644 index 000000000..9d26eb814 --- /dev/null +++ b/alembic/versions/20241029_3faa5bba3ddf_add_index_ix_licensepools_collection_id_.py @@ -0,0 +1,28 @@ +"""Add index ix_licensepools_collection_id_work_id + +Revision ID: 3faa5bba3ddf +Revises: 1938277e993f +Create Date: 2024-10-29 15:29:56.588830+00:00 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3faa5bba3ddf" +down_revision = "1938277e993f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_index( + "ix_licensepools_collection_id_work_id", + "licensepools", + ["collection_id", "work_id"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("ix_licensepools_collection_id_work_id", table_name="licensepools") diff --git a/pyproject.toml b/pyproject.toml index e50615c60..80b1f2742 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,12 +94,12 @@ module = [ "palace.manager.api.metadata.*", "palace.manager.api.odl.*", "palace.manager.api.opds_for_distributors", - "palace.manager.core.marc", "palace.manager.core.opds2_import", "palace.manager.core.opds_import", "palace.manager.core.selftest", "palace.manager.feed.*", "palace.manager.integration.*", + "palace.manager.marc.*", "palace.manager.opds.*", "palace.manager.scripts.initialization", "palace.manager.scripts.rotate_jwe_key", diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py index 09bb2d138..6bd84930e 100644 --- a/src/palace/manager/celery/tasks/marc.py +++ b/src/palace/manager/celery/tasks/marc.py @@ -1,16 +1,24 @@ import datetime +from contextlib import ExitStack +from tempfile import TemporaryFile from typing import Any from celery import shared_task +from pydantic import TypeAdapter from palace.manager.celery.task import Task from palace.manager.marc.exporter import LibraryInfo, MarcExporter -from palace.manager.marc.uploader import MarcUploadManager +from palace.manager.marc.uploader import MarcUploadManager, UploadContext from palace.manager.service.celery.celery import QueueNames -from palace.manager.service.redis.models.marc import ( - MarcFileUploadSession, - MarcFileUploadState, +from palace.manager.service.redis.models.lock import RedisLock +from palace.manager.service.redis.redis import Redis +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.identifier import ( + Identifier, + RecursiveEquivalencyCache, ) +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import utc_now @@ -26,73 +34,84 @@ def marc_export(task: Task, force: bool = False) -> None: start_time = utc_now() collections = MarcExporter.enabled_collections(session, registry) for collection in collections: - # Collection.id should never be able to be None here, but mypy doesn't know that. - # So we assert it for mypy's benefit. - assert collection.id is not None - upload_session = MarcFileUploadSession( - task.services.redis.client(), collection.id + libraries_info = MarcExporter.enabled_libraries( + session, registry, collection.id ) - with upload_session.lock() as acquired: - if not acquired: - task.log.info( - f"Skipping collection {collection.name} ({collection.id}) because another task holds its lock." 
- ) - continue + needs_update = any(info.needs_update for info in libraries_info) or force - if ( - upload_state := upload_session.state() - ) != MarcFileUploadState.INITIAL: - task.log.info( - f"Skipping collection {collection.name} ({collection.id}) because it is already being " - f"processed (state: {upload_state})." - ) - continue - - libraries_info = MarcExporter.enabled_libraries( - session, registry, collection.id + if not needs_update: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it has been updated recently." ) - needs_update = ( - any(info.needs_update for info in libraries_info) or force + continue + + if not MarcExporter.query_works( + session, + collection.id, + batch_size=1, + ): + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it has no works." ) + continue - if not needs_update: - task.log.info( - f"Skipping collection {collection.name} ({collection.id}) because it has been updated recently." - ) - continue + task.log.info( + f"Generating MARC records for collection {collection.name} ({collection.id})." + ) - works = MarcExporter.query_works( + marc_export_collection.delay( + collection_id=collection.id, + collection_name=collection.name, + start_time=start_time, + libraries=[l.model_dump() for l in libraries_info], + ) + + needs_delta = [l.model_dump() for l in libraries_info if l.last_updated] + if needs_delta: + min_last_updated = min( + [l.last_updated for l in libraries_info if l.last_updated] + ) + if not MarcExporter.query_works( session, collection.id, - work_id_offset=0, batch_size=1, - ) - if not works: + last_updated=min_last_updated, + ): task.log.info( - f"Skipping collection {collection.name} ({collection.id}) because it has no works." + f"Skipping delta for collection {collection.name} ({collection.id}) " + f"because no works have been updated." + ) + else: + marc_export_collection.delay( + collection_id=collection.id, + collection_name=collection.name, + start_time=start_time, + libraries=needs_delta, + delta=True, ) - continue - task.log.info( - f"Generating MARC records for collection {collection.name} ({collection.id})." - ) - upload_session.set_state(MarcFileUploadState.QUEUED) - marc_export_collection.delay( - collection_id=collection.id, - start_time=start_time, - libraries=[l.model_dump() for l in libraries_info], - ) + +def marc_export_collection_lock( + client: Redis, collection_id: int, delta: bool = False +) -> RedisLock: + return RedisLock( + client, + ["MarcUpload", Collection.redis_key_from_id(collection_id), f"Delta::{delta}"], + lock_timeout=datetime.timedelta(minutes=20), + ) @shared_task(queue=QueueNames.default, bind=True) def marc_export_collection( task: Task, collection_id: int, + collection_name: str, start_time: datetime.datetime, libraries: list[dict[str, Any]], - batch_size: int = 500, + context: dict[int, dict[str, Any]] | None = None, last_work_id: int | None = None, - update_number: int = 0, + batch_size: int = 1000, + delta: bool = False, ) -> None: """ Export MARC records for a single collection. 
@@ -104,64 +123,140 @@ def marc_export_collection( base_url = task.services.config.sitewide.base_url() storage_service = task.services.storage.public() - libraries_info = [LibraryInfo.model_validate(l) for l in libraries] - upload_manager = MarcUploadManager( - storage_service, - MarcFileUploadSession( - task.services.redis.client(), collection_id, update_number - ), + + # Parse data into pydantic models + libraries_info = TypeAdapter(list[LibraryInfo]).validate_python(libraries) + context_parsed = TypeAdapter(dict[int, UploadContext]).validate_python( + context or {} ) - with upload_manager.begin(): - if not upload_manager.locked: + + lock = marc_export_collection_lock( + task.services.redis.client(), collection_id, delta + ) + + with lock.lock() as locked: + if not locked: task.log.info( f"Skipping collection {collection_id} because another task is already processing it." ) return - with task.session() as session: - works = MarcExporter.query_works( - session, - collection_id, - work_id_offset=last_work_id, - batch_size=batch_size, - ) - for work in works: - MarcExporter.process_work( - work, libraries_info, base_url, upload_manager=upload_manager + with ExitStack() as stack, task.transaction() as session: + files = { + library: stack.enter_context(TemporaryFile()) + for library in libraries_info + } + uploads: dict[LibraryInfo, MarcUploadManager] = { + library: stack.enter_context( + MarcUploadManager( + storage_service, + collection_name, + library.library_short_name, + start_time, + library.last_updated if delta else None, + context_parsed.get(library.library_id), + ) ) + for library in libraries_info + } - # Sync the upload_manager to ensure that all the data is written to storage. - upload_manager.sync() + min_last_updated = ( + min([l.last_updated for l in libraries_info if l.last_updated]) + if delta + else None + ) - if len(works) != batch_size: - # We have finished generating MARC records. Cleanup and exit. - with task.transaction() as session: - collection = MarcExporter.collection(session, collection_id) - collection_name = collection.name if collection else "unknown" - completed_uploads = upload_manager.complete() - MarcExporter.create_marc_upload_records( + no_more_works = False + while not all( + [ + file.tell() > storage_service.MINIMUM_MULTIPART_UPLOAD_SIZE + for file in files.values() + ] + ): + works = MarcExporter.query_works( session, - start_time, collection_id, - libraries_info, - completed_uploads, + batch_size=batch_size, + work_id_offset=last_work_id, + last_updated=min_last_updated, ) - upload_manager.remove_session() - task.log.info( - f"Finished generating MARC records for collection '{collection_name}' ({collection_id})." 
- ) - return + if not works: + no_more_works = True + break + + # Set this for the next iteration + last_work_id = works[-1].id + + works_with_pools = [ + (work, pool) + for work in works + if (pool := work.active_license_pool()) is not None + ] + + # Find ISBN for any work that needs it + isbns = RecursiveEquivalencyCache.equivalent_identifiers( + session, + {pool.identifier for work, pool in works_with_pools}, + Identifier.ISBN, + ) + + for work, pool in works_with_pools: + isbn_identifier = isbns.get(pool.identifier) + records = MarcExporter.process_work( + work, pool, isbn_identifier, libraries_info, base_url, delta + ) + for library, record in records.items(): + files[library].write(record) + + # Upload part to s3, if there is anything to upload + for library, tmp_file in files.items(): + upload = uploads[library] + if not upload.upload_part(tmp_file): + task.log.warning( + f"No data to upload to s3 '{upload.context.s3_key}'." + ) + + if no_more_works: + # Task is complete. Finalize the s3 uploads and create MarcFile records in DB. + for library, upload in uploads.items(): + if upload.complete(): + create( + session, + MarcFile, + id=upload.context.upload_uuid, + library_id=library.library_id, + collection_id=collection_id, + created=start_time, + key=upload.context.s3_key, + since=library.last_updated if delta else None, + ) + task.log.info(f"Completed upload for '{upload.context.s3_key}'") + else: + task.log.warning( + f"No upload for '{upload.context.s3_key}', " + f"because there were no records." + ) + + task.log.info( + f"Finished generating MARC records for collection '{collection_name}' ({collection_id}) " + f"in {(utc_now() - start_time).seconds} seconds." + ) + return # This task is complete, but there are more works waiting to be exported. So we requeue ourselves # to process the next batch. raise task.replace( marc_export_collection.s( collection_id=collection_id, + collection_name=collection_name, start_time=start_time, libraries=[l.model_dump() for l in libraries_info], + context={ + l.library_id: uploads[l].context.model_dump() for l in libraries_info + }, + last_work_id=last_work_id, batch_size=batch_size, - last_work_id=works[-1].id, - update_number=upload_manager.update_number, + delta=delta, ) ) diff --git a/src/palace/manager/marc/annotator.py b/src/palace/manager/marc/annotator.py index 47446955f..ea1be3b8c 100644 --- a/src/palace/manager/marc/annotator.py +++ b/src/palace/manager/marc/annotator.py @@ -5,7 +5,6 @@ from collections.abc import Mapping, Sequence from pymarc import Field, Indicators, Record, Subfield -from sqlalchemy.orm import Session from palace.manager.core.classifier import Classifier from palace.manager.sqlalchemy.model.edition import Edition @@ -44,13 +43,15 @@ class Annotator(LoggerMixin): } @classmethod - def marc_record(cls, work: Work, license_pool: LicensePool) -> Record: + def marc_record( + cls, work: Work, isbn_identifier: Identifier | None, license_pool: LicensePool + ) -> Record: edition = license_pool.presentation_edition identifier = license_pool.identifier record = cls._record() cls.add_control_fields(record, identifier, license_pool, edition) - cls.add_isbn(record, identifier) + cls.add_isbn(record, isbn_identifier) # TODO: The 240 and 130 fields are for translated works, so they can be grouped even # though they have different titles. 
We do not group editions of the same work in @@ -82,6 +83,7 @@ def library_marc_record( organization_code: str | None, include_summary: bool, include_genres: bool, + delta: bool, ) -> Record: record = cls._copy_record(record) @@ -107,6 +109,9 @@ def library_marc_record( web_client_urls, ) + if delta: + cls.set_revised(record) + return record @classmethod @@ -201,28 +206,15 @@ def add_marc_organization_code(cls, record: Record, marc_org: str) -> None: record.add_field(Field(tag="003", data=marc_org)) @classmethod - def add_isbn(cls, record: Record, identifier: Identifier) -> None: + def add_isbn(cls, record: Record, identifier: Identifier | None) -> None: # Add the ISBN if we have one. - isbn = None - if identifier.type == Identifier.ISBN: - isbn = identifier - if not isbn: - _db = Session.object_session(identifier) - identifier_ids = identifier.equivalent_identifier_ids()[identifier.id] - isbn = ( - _db.query(Identifier) - .filter(Identifier.type == Identifier.ISBN) - .filter(Identifier.id.in_(identifier_ids)) - .order_by(Identifier.id) - .first() - ) - if isbn and isbn.identifier: + if identifier and identifier.identifier: record.add_field( Field( tag="020", indicators=Indicators(" ", " "), subfields=[ - Subfield(code="a", value=isbn.identifier), + Subfield(code="a", value=identifier.identifier), ], ) ) diff --git a/src/palace/manager/marc/exporter.py b/src/palace/manager/marc/exporter.py index 5b2ead5e8..3d6b79835 100644 --- a/src/palace/manager/marc/exporter.py +++ b/src/palace/manager/marc/exporter.py @@ -2,12 +2,10 @@ import datetime from collections.abc import Generator, Iterable, Sequence -from uuid import UUID, uuid4 -import pytz from pydantic import BaseModel, ConfigDict from sqlalchemy import select -from sqlalchemy.orm import Session, aliased +from sqlalchemy.orm import Session, aliased, raiseload, selectinload from palace.manager.integration.base import HasLibraryIntegrationConfiguration from palace.manager.integration.goals import Goals @@ -16,26 +14,29 @@ MarcExporterLibrarySettings, MarcExporterSettings, ) -from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.integration_registry.catalog_services import ( CatalogServicesRegistry, ) from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.contributor import Contribution from palace.manager.sqlalchemy.model.discovery_service_registration import ( DiscoveryServiceRegistration, ) +from palace.manager.sqlalchemy.model.edition import Edition +from palace.manager.sqlalchemy.model.identifier import Identifier from palace.manager.sqlalchemy.model.integration import ( IntegrationConfiguration, IntegrationLibraryConfiguration, ) from palace.manager.sqlalchemy.model.library import Library -from palace.manager.sqlalchemy.model.licensing import LicensePool +from palace.manager.sqlalchemy.model.licensing import ( + LicensePool, + LicensePoolDeliveryMechanism, +) from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.sqlalchemy.model.work import Work -from palace.manager.sqlalchemy.util import create +from palace.manager.sqlalchemy.model.work import Work, WorkGenre from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.log import LoggerMixin -from palace.manager.util.uuid import uuid_encode class LibraryInfo(BaseModel): @@ -48,11 +49,6 @@ class LibraryInfo(BaseModel): include_genres: bool web_client_urls: tuple[str, ...] 
- s3_key_full_uuid: str - s3_key_full: str - - s3_key_delta_uuid: str - s3_key_delta: str | None = None model_config = ConfigDict(frozen=True) @@ -84,35 +80,6 @@ def settings_class(cls) -> type[MarcExporterSettings]: def library_settings_class(cls) -> type[MarcExporterLibrarySettings]: return MarcExporterLibrarySettings - @staticmethod - def _s3_key( - library: Library, - collection: Collection, - creation_time: datetime.datetime, - uuid: UUID, - since_time: datetime.datetime | None = None, - ) -> str: - """The path to the hosted MARC file for the given library, collection, - and date range.""" - - def date_to_string(date: datetime.datetime) -> str: - return date.astimezone(pytz.UTC).strftime("%Y-%m-%d") - - root = "marc" - short_name = str(library.short_name) - creation = date_to_string(creation_time) - - if since_time: - file_type = f"delta.{date_to_string(since_time)}.{creation}" - else: - file_type = f"full.{creation}" - - uuid_encoded = uuid_encode(uuid) - collection_name = collection.name.replace(" ", "_") - filename = f"{collection_name}.{file_type}.{uuid_encoded}.mrc" - parts = [root, short_name, filename] - return "/".join(parts) - @staticmethod def _needs_update( last_updated_time: datetime.datetime | None, update_frequency: int @@ -206,10 +173,15 @@ def enabled_collections( @classmethod def enabled_libraries( - cls, session: Session, registry: CatalogServicesRegistry, collection_id: int + cls, + session: Session, + registry: CatalogServicesRegistry, + collection_id: int | None, ) -> Sequence[LibraryInfo]: + if collection_id is None: + return [] + library_info = [] - creation_time = utc_now() for collection, library_integration in cls._enabled_collections_and_libraries( session, registry, collection_id ): @@ -230,25 +202,6 @@ def enabled_libraries( web_client_urls = cls._web_client_urls( session, library, library_settings.web_client_url ) - s3_key_full_uuid = uuid4() - s3_key_full = cls._s3_key( - library, - collection, - creation_time, - s3_key_full_uuid, - ) - s3_key_delta_uuid = uuid4() - s3_key_delta = ( - cls._s3_key( - library, - collection, - creation_time, - s3_key_delta_uuid, - since_time=last_updated_time, - ) - if last_updated_time - else None - ) library_info.append( LibraryInfo( library_id=library_id, @@ -259,10 +212,6 @@ def enabled_libraries( include_summary=library_settings.include_summary, include_genres=library_settings.include_genres, web_client_urls=web_client_urls, - s3_key_full_uuid=str(s3_key_full_uuid), - s3_key_full=s3_key_full, - s3_key_delta_uuid=str(s3_key_delta_uuid), - s3_key_delta=s3_key_delta, ) ) library_info.sort(key=lambda info: info.library_id) @@ -271,10 +220,14 @@ def enabled_libraries( @staticmethod def query_works( session: Session, - collection_id: int, - work_id_offset: int | None, + collection_id: int | None, batch_size: int, + work_id_offset: int | None = None, + last_updated: datetime.datetime | None = None, ) -> list[Work]: + if collection_id is None: + return [] + query = ( select(Work) .join(LicensePool) @@ -283,12 +236,38 @@ def query_works( ) .limit(batch_size) .order_by(Work.id.asc()) + .options( + # We set loader options on all the collection properties + # needed to generate the MARC records, so that we don't end + # up doing queries for each work. 
+ selectinload(Work.license_pools).options( + selectinload(LicensePool.identifier), + selectinload(LicensePool.presentation_edition).options( + selectinload(Edition.contributions).options( + selectinload(Contribution.contributor) + ) + ), + selectinload(LicensePool.delivery_mechanisms).options( + selectinload(LicensePoolDeliveryMechanism.delivery_mechanism) + ), + selectinload(LicensePool.data_source), + ), + selectinload(Work.work_genres).options(selectinload(WorkGenre.genre)), + # We set raiseload on all the other properties, so we quickly know if + # a change causes us to start having to issue queries to get a property. + # This will raise a InvalidRequestError, that should fail our tests, so + # we know to add the new required properties to this function. + raiseload("*"), + ) ) - if work_id_offset is not None: + if last_updated: + query = query.where(Work.last_update_time > last_updated) + + if work_id_offset: query = query.where(Work.id > work_id_offset) - return session.execute(query).scalars().unique().all() + return session.execute(query).scalars().all() @staticmethod def collection(session: Session, collection_id: int) -> Collection | None: @@ -296,79 +275,38 @@ def collection(session: Session, collection_id: int) -> Collection | None: select(Collection).where(Collection.id == collection_id) ).scalar_one_or_none() - @classmethod + @staticmethod def process_work( - cls, work: Work, + license_pool: LicensePool, + isbn_identifier: Identifier | None, libraries_info: Iterable[LibraryInfo], base_url: str, + delta: bool, *, - upload_manager: MarcUploadManager, annotator: type[Annotator] = Annotator, - ) -> None: - pool = work.active_license_pool() - if pool is None: - return - base_record = annotator.marc_record(work, pool) - - for library_info in libraries_info: - library_record = annotator.library_marc_record( + ) -> dict[LibraryInfo, bytes]: + base_record = annotator.marc_record(work, isbn_identifier, license_pool) + return { + library_info: annotator.library_marc_record( base_record, - pool.identifier, + license_pool.identifier, base_url, library_info.library_short_name, library_info.web_client_urls, library_info.organization_code, library_info.include_summary, library_info.include_genres, - ) - - upload_manager.add_record( - library_info.s3_key_full, - library_record.as_marc(), - ) - - if ( - library_info.last_updated - and library_info.s3_key_delta - and work.last_update_time + delta, + ).as_marc() + for library_info in libraries_info + if not delta + or ( + work.last_update_time + and library_info.last_updated and work.last_update_time > library_info.last_updated - ): - upload_manager.add_record( - library_info.s3_key_delta, - annotator.set_revised(library_record).as_marc(), - ) - - @staticmethod - def create_marc_upload_records( - session: Session, - start_time: datetime.datetime, - collection_id: int, - libraries_info: Iterable[LibraryInfo], - uploaded_keys: set[str], - ) -> None: - for library_info in libraries_info: - if library_info.s3_key_full in uploaded_keys: - create( - session, - MarcFile, - id=library_info.s3_key_full_uuid, - library_id=library_info.library_id, - collection_id=collection_id, - created=start_time, - key=library_info.s3_key_full, - ) - if library_info.s3_key_delta and library_info.s3_key_delta in uploaded_keys: - create( - session, - MarcFile, - id=library_info.s3_key_delta_uuid, - library_id=library_info.library_id, - collection_id=collection_id, - created=start_time, - since=library_info.last_updated, - key=library_info.s3_key_delta, - ) + ) + } 
@staticmethod def files_for_cleanup( diff --git a/src/palace/manager/marc/uploader.py b/src/palace/manager/marc/uploader.py index ff43a4951..f6a19e0c7 100644 --- a/src/palace/manager/marc/uploader.py +++ b/src/palace/manager/marc/uploader.py @@ -1,148 +1,176 @@ -from collections import defaultdict -from collections.abc import Generator, Sequence -from contextlib import contextmanager +import datetime +import uuid +from types import TracebackType +from typing import IO, Literal -from celery.exceptions import Ignore, Retry +from pydantic import BaseModel from typing_extensions import Self -from palace.manager.service.redis.models.marc import MarcFileUploadSession +from palace.manager.core.exceptions import BasePalaceException from palace.manager.service.storage.s3 import MultipartS3UploadPart, S3Service from palace.manager.sqlalchemy.model.resource import Representation from palace.manager.util.log import LoggerMixin +from palace.manager.util.uuid import uuid_encode + + +class UploadContext(BaseModel): + upload_uuid: uuid.UUID + s3_key: str + upload_id: str | None = None + parts: list[MultipartS3UploadPart] = [] + + +class MarcUploadException(BasePalaceException): ... class MarcUploadManager(LoggerMixin): """ This class is used to manage the upload of MARC files to S3. The upload is done in multiple - parts, so that the Celery task can be broken up into multiple steps, saving the progress - between steps to redis, and flushing them to S3 when the buffer is large enough. - - This class orchestrates the upload process, delegating the redis operation to the - `MarcFileUploadSession` class, and the S3 upload to the `S3Service` class. + parts, so that the Celery task can be broken up into multiple steps. """ def __init__( - self, storage_service: S3Service, upload_session: MarcFileUploadSession + self, + storage_service: S3Service, + collection_name: str, + library_short_name: str, + creation_time: datetime.datetime, + since_time: datetime.datetime | None, + context: UploadContext | None = None, ): self.storage_service = storage_service - self.upload_session = upload_session - self._buffers: defaultdict[str, str] = defaultdict(str) - self._locked = False - - @property - def locked(self) -> bool: - return self._locked - - @property - def update_number(self) -> int: - return self.upload_session.update_number + self._in_context_manager = False + self._finalized = False + + if context is None: + upload_uuid = uuid.uuid4() + s3_key = self._s3_key( + library_short_name, + collection_name, + creation_time, + upload_uuid, + since_time, + ) + context = UploadContext( + upload_uuid=upload_uuid, + s3_key=s3_key, + ) + self.context = context + + @staticmethod + def _s3_key( + library_short_name: str, + collection_name: str, + creation_time: datetime.datetime, + upload_uuid: uuid.UUID, + since_time: datetime.datetime | None = None, + ) -> str: + """The path to the hosted MARC file for the given library, collection, + and date range.""" + + def date_to_string(date: datetime.datetime) -> str: + return date.astimezone(datetime.timezone.utc).strftime("%Y-%m-%d") + + root = "marc" + creation = date_to_string(creation_time) + + if since_time: + file_type = f"delta.{date_to_string(since_time)}.{creation}" + else: + file_type = f"full.{creation}" + + uuid_encoded = uuid_encode(upload_uuid) + collection_name = collection_name.replace(" ", "_") + filename = f"{collection_name}.{file_type}.{uuid_encoded}.mrc" + parts = [root, library_short_name, filename] + return "/".join(parts) + + def __enter__(self) -> Self: + if 
self._in_context_manager: + raise MarcUploadException(f"Cannot nest {self.__class__.__name__}.") + self._in_context_manager = True + return self + + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> Literal[False]: + if excinst is not None and not self._finalized: + self.log.error( + "An exception occurred during upload of MARC files. Cancelling in progress upload.", + ) + try: + self.abort() + except Exception as e: + # We log and keep going, since this was already triggered by an exception. + self.log.exception( + f"Failed to abort upload {self.context.s3_key} (UploadID: {self.context.upload_id}) due to exception ({e})." + ) - def add_record(self, key: str, record: bytes) -> None: - self._buffers[key] += record.decode() + self._in_context_manager = False + return False - def _s3_upload_part(self, key: str, upload_id: str) -> MultipartS3UploadPart: - part_number, data = self.upload_session.get_part_num_and_buffer(key) - upload_part = self.storage_service.multipart_upload( - key, upload_id, part_number, data.encode() + def begin_upload(self) -> str: + upload_id = self.storage_service.multipart_create( + self.context.s3_key, content_type=Representation.MARC_MEDIA_TYPE ) - self.upload_session.add_part_and_clear_buffer(key, upload_part) - return upload_part - - def _s3_sync(self, needs_upload: Sequence[str]) -> None: - upload_ids = self.upload_session.get_upload_ids(needs_upload) - for key in needs_upload: - if upload_ids.get(key) is None: - upload_id = self.storage_service.multipart_create( - key, content_type=Representation.MARC_MEDIA_TYPE - ) - self.upload_session.set_upload_id(key, upload_id) - upload_ids[key] = upload_id + self.context.upload_id = upload_id + return upload_id - self._s3_upload_part(key, upload_ids[key]) + def upload_part(self, data: IO[bytes] | bytes) -> bool: + if self._finalized: + raise MarcUploadException("Upload is already finalized.") - def _sync_buffers_to_redis(self) -> dict[str, int]: - buffer_lengths = self.upload_session.append_buffers(self._buffers) - self._buffers.clear() - return buffer_lengths + if isinstance(data, bytes): + length = len(data) + else: + length = data.tell() + data.seek(0) - def sync(self) -> None: - # First sync our buffers to redis - buffer_lengths = self._sync_buffers_to_redis() + if length == 0: + return False - # Then, if any of our redis buffers are large enough sync them to S3. - needs_upload = [ - key - for key, length in buffer_lengths.items() - if length > self.storage_service.MINIMUM_MULTIPART_UPLOAD_SIZE - ] + if self.context.upload_id is None: + upload_id = self.begin_upload() + else: + upload_id = self.context.upload_id - if not needs_upload: - return + part_number = len(self.context.parts) + 1 + upload_part = self.storage_service.multipart_upload( + self.context.s3_key, upload_id, part_number, data + ) + self.context.parts.append(upload_part) + return True - self._s3_sync(needs_upload) + def complete(self) -> bool: + if self._finalized: + raise MarcUploadException("Upload is already finalized.") - def _abort(self) -> None: - in_progress = self.upload_session.get() - for key, upload in in_progress.items(): - if upload.upload_id is None: - # This upload has not started, so there is nothing to abort. 
- continue - try: - self.storage_service.multipart_abort(key, upload.upload_id) - except Exception as e: - # We log and keep going, since we want to abort as many uploads as possible - # even if some fail, this is likely already being called in an exception handler. - # So we want to do as much cleanup as possible. - self.log.exception( - f"Failed to abort upload {key} (UploadID: {upload.upload_id}) due to exception ({e})." - ) + if self.context.upload_id is None or not self.context.parts: + self.abort() + return False - # Delete our in-progress uploads from redis as well - self.remove_session() + self.storage_service.multipart_complete( + self.context.s3_key, self.context.upload_id, self.context.parts + ) + self._finalized = True + return True - def complete(self) -> set[str]: - # Ensure any local data is synced to Redis. - self._sync_buffers_to_redis() + def abort(self) -> None: + if self._finalized: + return - in_progress = self.upload_session.get() - for key, upload in in_progress.items(): - if upload.upload_id is None: - # The multipart upload hasn't started. Perform a regular S3 upload since all data is in the buffer. - self.storage_service.store( - key, upload.buffer, Representation.MARC_MEDIA_TYPE - ) - else: - if upload.buffer != "": - # Upload the last chunk if the buffer is not empty. The final part has no minimum size requirement. - last_part = self._s3_upload_part(key, upload.upload_id) - upload.parts.append(last_part) - - # Complete the multipart upload. - self.storage_service.multipart_complete( - key, upload.upload_id, upload.parts - ) + if self.context.upload_id is None: + self._finalized = True + return - # Delete the in-progress uploads data from Redis. - if in_progress: - self.upload_session.clear_uploads() - - # Return the keys that were uploaded. - return set(in_progress.keys()) - - def remove_session(self) -> None: - self.upload_session.delete() - - @contextmanager - def begin(self) -> Generator[Self, None, None]: - self._locked = self.upload_session.acquire() - try: - yield self - except Exception as e: - # We want to ignore any celery exceptions that are expected, but - # handle cleanup for any other cases. - if not isinstance(e, (Retry, Ignore)): - self._abort() - raise - finally: - self.upload_session.release() - self._locked = False + self.storage_service.multipart_abort( + self.context.s3_key, self.context.upload_id + ) + self._finalized = True + + @property + def finalized(self) -> bool: + return self._finalized diff --git a/src/palace/manager/service/redis/escape.py b/src/palace/manager/service/redis/escape.py deleted file mode 100644 index 420d81cbb..000000000 --- a/src/palace/manager/service/redis/escape.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -import json -from functools import cached_property - -from palace.manager.core.exceptions import PalaceValueError - - -class JsonPathEscapeMixin: - r""" - Mixin to provide methods for escaping and unescaping JsonPaths for use in Redis / ElastiCache. - - This is necessary because some characters in object keys are not handled well by AWS ElastiCache, - and other characters seem problematic in Redis. - - This mixin provides methods to escape and unescape these characters, so that they can be used in - object keys, and the keys can be queried via JSONPath without issue. - - In ElastiCache when ~ is used in a key, the key is never updated, despite returning a success. And - when a / is used in a key, the key is interpreted as a nested path, nesting a new key for every - slash in the path. 
This is not the behavior we want, so we need to escape these characters. - - In Redis, the \ character is used as an escape character, and the " character is used to denote - the end of a string for the JSONPath. This means that these characters need to be escaped as well. - - Characters are escaped by prefixing them with a backtick character, followed by a single character - from _MAPPING that represents the escaped character. The backtick character itself is escaped by - prefixing it with another backtick character. - """ - - _ESCAPE_CHAR = "`" - - _MAPPING = { - "/": "s", - "\\": "b", - '"': "'", - "~": "t", - } - - @cached_property - def _FORWARD_MAPPING(self) -> dict[str, str]: - mapping = {k: "".join((self._ESCAPE_CHAR, v)) for k, v in self._MAPPING.items()} - mapping[self._ESCAPE_CHAR] = "".join((self._ESCAPE_CHAR, self._ESCAPE_CHAR)) - return mapping - - @cached_property - def _REVERSE_MAPPING(self) -> dict[str, str]: - mapping = {v: k for k, v in self._MAPPING.items()} - mapping[self._ESCAPE_CHAR] = self._ESCAPE_CHAR - return mapping - - def _escape_path(self, path: str, elasticache: bool = False) -> str: - escaped = "".join([self._FORWARD_MAPPING.get(c, c) for c in path]) - if elasticache: - # As well as the simple escaping we have defined here, for ElastiCache we need to fully - # escape the path as if it were a JSON string. So we call json.dumps to do this. We - # strip the leading and trailing quotes from the result, as we only want the escaped - # string, not the quotes. - escaped = json.dumps(escaped)[1:-1] - return escaped - - def _unescape_path(self, path: str) -> str: - in_escape = False - unescaped = [] - for char in path: - if in_escape: - if char not in self._REVERSE_MAPPING: - raise PalaceValueError( - f"Invalid escape sequence '{self._ESCAPE_CHAR}{char}'" - ) - unescaped.append(self._REVERSE_MAPPING[char]) - in_escape = False - elif char == self._ESCAPE_CHAR: - in_escape = True - else: - unescaped.append(char) - - if in_escape: - raise PalaceValueError("Unterminated escape sequence.") - - return "".join(unescaped) diff --git a/src/palace/manager/service/redis/models/lock.py b/src/palace/manager/service/redis/models/lock.py index 1a6eba9ee..ef4c348af 100644 --- a/src/palace/manager/service/redis/models/lock.py +++ b/src/palace/manager/service/redis/models/lock.py @@ -1,16 +1,13 @@ -import json import random import time from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator, Sequence from contextlib import contextmanager from datetime import timedelta from functools import cached_property -from typing import Any, TypeVar, cast +from typing import cast from uuid import uuid4 -from redis.exceptions import ResponseError - from palace.manager.celery.task import Task from palace.manager.core.exceptions import BasePalaceException from palace.manager.service.redis.redis import Redis @@ -235,213 +232,3 @@ def __init__( else: name = [lock_name] super().__init__(redis_client, name, random_value, lock_timeout, retry_delay) - - -class RedisJsonLock(BaseRedisLock, ABC): - _GET_LOCK_FUNCTION = """ - local function get_lock_value(key, json_key) - local value = redis.call("json.get", key, json_key) - if not value then - return nil - end - return cjson.decode(value)[1] - end - """ - - _ACQUIRE_SCRIPT = f""" - {_GET_LOCK_FUNCTION} - -- If the locks json object doesn't exist, create it with the initial value - redis.call("json.set", KEYS[1], "$", ARGV[4], "nx") - - -- Get the current lock value - local 
lock_value = get_lock_value(KEYS[1], ARGV[1]) - if not lock_value then - -- The lock isn't currently locked, so we lock it and set the timeout - redis.call("json.set", KEYS[1], ARGV[1], cjson.encode(ARGV[2])) - redis.call("pexpire", KEYS[1], ARGV[3]) - return 1 - elseif lock_value == ARGV[2] then - -- The lock is already held by us, so we extend the timeout - redis.call("pexpire", KEYS[1], ARGV[3]) - return 2 - else - -- The lock is held by someone else, we do nothing - return nil - end - """ - - _RELEASE_SCRIPT = f""" - {_GET_LOCK_FUNCTION} - if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then - redis.call("json.del", KEYS[1], ARGV[1]) - return 1 - else - return nil - end - """ - - _EXTEND_SCRIPT = f""" - {_GET_LOCK_FUNCTION} - if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then - redis.call("pexpire", KEYS[1], ARGV[3]) - return 1 - else - return nil - end - """ - - _DELETE_SCRIPT = f""" - {_GET_LOCK_FUNCTION} - if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then - redis.call("del", KEYS[1]) - return 1 - else - return nil - end - """ - - def __init__( - self, - redis_client: Redis, - random_value: str | None = None, - ): - super().__init__(redis_client, random_value) - - # Register our scripts - self._acquire_script = self._redis_client.register_script(self._ACQUIRE_SCRIPT) - self._release_script = self._redis_client.register_script(self._RELEASE_SCRIPT) - self._extend_script = self._redis_client.register_script(self._EXTEND_SCRIPT) - self._delete_script = self._redis_client.register_script(self._DELETE_SCRIPT) - - @property - @abstractmethod - def _lock_timeout_ms(self) -> int: - """ - The lock timeout in milliseconds. - """ - ... - - @property - def _lock_json_key(self) -> str: - """ - The key to use for the lock value in the JSON object. - - This can be overridden if you need to store the lock value in a different key. It should - be a Redis JSONPath. - See: https://redis.io/docs/latest/develop/data-types/json/path/ - """ - return "$.lock" - - @property - def _initial_value(self) -> str: - """ - The initial value to use for the locks JSON object. - """ - return json.dumps({}) - - T = TypeVar("T") - - @classmethod - def _parse_multi( - cls, value: Mapping[str, Sequence[T]] | None - ) -> dict[str, T | None]: - """ - Helper function that makes it easier to work with the results of a JSON GET command, - where you request multiple keys. - """ - if value is None: - return {} - return {k: cls._parse_value(v) for k, v in value.items()} - - @staticmethod - def _parse_value(value: Sequence[T] | None) -> T | None: - """ - Helper function to parse the value from the results of a JSON GET command, where you - expect the JSONPath to return a single value. - """ - if value is None: - return None - try: - return value[0] - except IndexError: - return None - - @classmethod - def _parse_value_or_raise(cls, value: Sequence[T] | None) -> T: - """ - Wrapper around _parse_value that raises an exception if the value is None. - """ - parsed_value = cls._parse_value(value) - if parsed_value is None: - raise LockError(f"Could not parse value ({json.dumps(value)})") - return parsed_value - - @staticmethod - def _validate_pipeline_results(results: list[Any]) -> bool: - """ - This function validates that all the results of the pipeline are successful, - and not a ResponseError. - - NOTE: The AWS ElastiCache implementation returns slightly different results than Redis. 
- In Redis, unsuccessful results when a key is not found are `None`, but in AWS they are - returned as a `ResponseError`, so we are careful to check for both in this function. - """ - return all(r and not isinstance(r, ResponseError) for r in results) - - def acquire(self) -> bool: - return ( - self._acquire_script( - keys=(self.key,), - args=( - self._lock_json_key, - self._random_value, - self._lock_timeout_ms, - self._initial_value, - ), - ) - is not None - ) - - def release(self) -> bool: - """ - Release the lock. - - You must have the lock to release it. This will unset the lock value in the JSON object, but importantly - it will not delete the JSON object itself. If you want to delete the JSON object, use the delete method. - """ - return ( - self._release_script( - keys=(self.key,), - args=(self._lock_json_key, self._random_value), - ) - is not None - ) - - def locked(self, by_us: bool = False) -> bool: - lock_value: str | None = self._parse_value( - self._redis_client.json().get(self.key, self._lock_json_key) - ) - if by_us: - return lock_value == self._random_value - return lock_value is not None - - def extend_timeout(self) -> bool: - return ( - self._extend_script( - keys=(self.key,), - args=(self._lock_json_key, self._random_value, self._lock_timeout_ms), - ) - is not None - ) - - def delete(self) -> bool: - """ - Delete the whole json object, including the lock. Must have the lock to delete the object. - """ - return ( - self._delete_script( - keys=(self.key,), - args=(self._lock_json_key, self._random_value), - ) - is not None - ) diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py deleted file mode 100644 index 6b881e4ac..000000000 --- a/src/palace/manager/service/redis/models/marc.py +++ /dev/null @@ -1,319 +0,0 @@ -from __future__ import annotations - -import json -import sys -from collections.abc import Callable, Generator, Mapping, Sequence -from contextlib import contextmanager -from enum import auto -from functools import cached_property -from typing import Any - -from pydantic import BaseModel -from redis import WatchError - -from palace.manager.service.redis.escape import JsonPathEscapeMixin -from palace.manager.service.redis.models.lock import LockError, RedisJsonLock -from palace.manager.service.redis.redis import Pipeline, Redis -from palace.manager.service.storage.s3 import MultipartS3UploadPart -from palace.manager.sqlalchemy.model.collection import Collection -from palace.manager.util.log import LoggerMixin - -# TODO: Remove this when we drop support for Python 3.10 -if sys.version_info >= (3, 11): - from enum import StrEnum -else: - from backports.strenum import StrEnum - - -class MarcFileUploadSessionError(LockError): - pass - - -class MarcFileUpload(BaseModel): - buffer: str = "" - upload_id: str | None = None - parts: list[MultipartS3UploadPart] = [] - - -class MarcFileUploadState(StrEnum): - INITIAL = auto() - QUEUED = auto() - UPLOADING = auto() - - -class MarcFileUploadSession(RedisJsonLock, JsonPathEscapeMixin, LoggerMixin): - """ - This class is used as a lock for the Celery MARC export task, to ensure that only one - task can upload MARC files for a given collection at a time. It increments an update - number each time an update is made, to guard against corruption if a task gets run - twice. 
- - It stores the intermediate results of the MARC file generation process, so that the task - can complete in multiple steps, saving the progress between steps to redis, and flushing - them to S3 when the buffer is full. - - This object is focused on the redis part of this operation, the actual s3 upload orchestration - is handled by the `MarcUploadManager` class. - """ - - def __init__( - self, - redis_client: Redis, - collection_id: int, - update_number: int = 0, - ): - super().__init__(redis_client) - self._collection_id = collection_id - self._update_number = update_number - - @cached_property - def key(self) -> str: - return self._redis_client.get_key( - self.__class__.__name__, - Collection.redis_key_from_id(self._collection_id), - ) - - @property - def _lock_timeout_ms(self) -> int: - return 20 * 60 * 1000 # 20 minutes - - @property - def update_number(self) -> int: - return self._update_number - - @property - def _initial_value(self) -> str: - """ - The initial value to use for the locks JSON object. - """ - return json.dumps( - {"uploads": {}, "update_number": 0, "state": MarcFileUploadState.INITIAL} - ) - - @property - def _update_number_json_key(self) -> str: - return "$.update_number" - - @property - def _uploads_json_key(self) -> str: - return "$.uploads" - - @property - def _state_json_key(self) -> str: - return "$.state" - - @staticmethod - def _upload_initial_value(buffer_data: str) -> dict[str, Any]: - return MarcFileUpload(buffer=buffer_data).model_dump(exclude_none=True) - - def _upload_path(self, upload_key: str) -> str: - upload_key = self._escape_path(upload_key, self._redis_client.elasticache) - return f'{self._uploads_json_key}["{upload_key}"]' - - def _buffer_path(self, upload_key: str) -> str: - upload_path = self._upload_path(upload_key) - return f"{upload_path}.buffer" - - def _upload_id_path(self, upload_key: str) -> str: - upload_path = self._upload_path(upload_key) - return f"{upload_path}.upload_id" - - def _parts_path(self, upload_key: str) -> str: - upload_path = self._upload_path(upload_key) - return f"{upload_path}.parts" - - @contextmanager - def _pipeline(self, begin_transaction: bool = True) -> Generator[Pipeline]: - with self._redis_client.pipeline() as pipe: - pipe.watch(self.key) - fetched_data = self._parse_multi( - pipe.json().get( - self.key, self._lock_json_key, self._update_number_json_key - ) - ) - # Check that we hold the lock - if ( - remote_random := fetched_data.get(self._lock_json_key) - ) != self._random_value: - raise MarcFileUploadSessionError( - f"Must hold lock to update upload session. " - f"Expected: {self._random_value}, got: {remote_random}" - ) - # Check that the update number is correct - if ( - remote_update_number := fetched_data.get(self._update_number_json_key) - ) != self._update_number: - raise MarcFileUploadSessionError( - f"Update number mismatch. " - f"Expected: {self._update_number}, got: {remote_update_number}" - ) - if begin_transaction: - pipe.multi() - yield pipe - - def _execute_pipeline( - self, - pipe: Pipeline, - updates: int, - *, - state: MarcFileUploadState = MarcFileUploadState.UPLOADING, - ) -> list[Any]: - if not pipe.explicit_transaction: - raise MarcFileUploadSessionError( - "Pipeline should be in explicit transaction mode before executing." 
- ) - pipe.json().set(self.key, path=self._state_json_key, obj=state) - pipe.json().numincrby(self.key, self._update_number_json_key, updates) - pipe.pexpire(self.key, self._lock_timeout_ms) - try: - pipe_results = pipe.execute(raise_on_error=False) - except WatchError as e: - raise MarcFileUploadSessionError( - "Failed to update buffers. Another process is modifying the buffers." - ) from e - self._update_number = self._parse_value_or_raise(pipe_results[-2]) - - return pipe_results[:-3] - - def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]: - if not data: - return {} - - set_results = {} - with self._pipeline(begin_transaction=False) as pipe: - existing_uploads: list[str] = self._parse_value_or_raise( - pipe.json().objkeys(self.key, self._uploads_json_key) - ) - existing_uploads = [self._unescape_path(p) for p in existing_uploads] - pipe.multi() - for key, value in data.items(): - if value == "": - continue - if key in existing_uploads: - pipe.json().strappend( - self.key, path=self._buffer_path(key), value=value - ) - else: - pipe.json().set( - self.key, - path=(self._upload_path(key)), - obj=self._upload_initial_value(value), - ) - set_results[key] = len(value) - - pipe_results = self._execute_pipeline(pipe, len(data)) - - if not self._validate_pipeline_results(pipe_results): - raise MarcFileUploadSessionError("Failed to append buffers.") - - return { - k: set_results[k] if v is True else self._parse_value_or_raise(v) - for k, v in zip(data.keys(), pipe_results) - } - - def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> None: - with self._pipeline() as pipe: - pipe.json().arrappend( - self.key, - self._parts_path(key), - part.model_dump(), - ) - pipe.json().set( - self.key, - path=self._buffer_path(key), - obj="", - ) - pipe_results = self._execute_pipeline(pipe, 1) - - if not self._validate_pipeline_results(pipe_results): - raise MarcFileUploadSessionError("Failed to add part and clear buffer.") - - def set_upload_id(self, key: str, upload_id: str) -> None: - with self._pipeline() as pipe: - pipe.json().set( - self.key, - path=self._upload_id_path(key), - obj=upload_id, - nx=True, - ) - pipe_results = self._execute_pipeline(pipe, 1) - - if not self._validate_pipeline_results(pipe_results): - raise MarcFileUploadSessionError("Failed to set upload ID.") - - def clear_uploads(self) -> None: - with self._pipeline() as pipe: - pipe.json().clear(self.key, self._uploads_json_key) - pipe_results = self._execute_pipeline(pipe, 1) - - if not self._validate_pipeline_results(pipe_results): - raise MarcFileUploadSessionError("Failed to clear uploads.") - - def _get_specific( - self, - keys: str | Sequence[str], - get_path: Callable[[str], str], - ) -> dict[str, Any]: - if isinstance(keys, str): - keys = [keys] - paths = {get_path(k): k for k in keys} - results = self._redis_client.json().get(self.key, *paths.keys()) - if len(keys) == 1: - return {keys[0]: self._parse_value(results)} - else: - return {paths[k]: v for k, v in self._parse_multi(results).items()} - - def _get_all(self, key: str) -> dict[str, Any]: - get_results = self._redis_client.json().get(self.key, key) - results: dict[str, Any] | None = self._parse_value(get_results) - - if results is None: - return {} - - return {self._unescape_path(k): v for k, v in results.items()} - - def get(self, keys: str | Sequence[str] | None = None) -> dict[str, MarcFileUpload]: - if keys is None: - uploads = self._get_all(self._uploads_json_key) - else: - uploads = self._get_specific(keys, self._upload_path) 
- - return { - k: MarcFileUpload.model_validate(v) - for k, v in uploads.items() - if v is not None - } - - def get_upload_ids(self, keys: str | Sequence[str]) -> dict[str, str]: - return self._get_specific(keys, self._upload_id_path) - - def get_part_num_and_buffer(self, key: str) -> tuple[int, str]: - with self._redis_client.pipeline() as pipe: - pipe.json().get(self.key, self._buffer_path(key)) - pipe.json().arrlen(self.key, self._parts_path(key)) - results = pipe.execute(raise_on_error=False) - if not self._validate_pipeline_results(results): - raise MarcFileUploadSessionError( - "Failed to get part number and buffer data." - ) - - buffer_data: str = self._parse_value_or_raise(results[0]) - # AWS S3 requires part numbers to start at 1, so we need to increment by 1. - # - # NOTE: This is not true in MinIO (our local development environment). MinIO - # allows both 0 and 1 as the first part number. Therefore, tests will pass if this is - # changed, but it will fail when running in an actual AWS environment. - part_number: int = self._parse_value_or_raise(results[1]) + 1 - - return part_number, buffer_data - - def state(self) -> MarcFileUploadState | None: - get_results = self._redis_client.json().get(self.key, self._state_json_key) - state: str | None = self._parse_value(get_results) - if state is None: - return None - return MarcFileUploadState(state) - - def set_state(self, state: MarcFileUploadState) -> None: - with self._pipeline() as pipe: - self._execute_pipeline(pipe, 0, state=state) diff --git a/src/palace/manager/service/redis/redis.py b/src/palace/manager/service/redis/redis.py index 802b1127d..d1eb3bbb9 100644 --- a/src/palace/manager/service/redis/redis.py +++ b/src/palace/manager/service/redis/redis.py @@ -97,15 +97,6 @@ def key_args(self, args: list[Any]) -> Sequence[str]: RedisCommandArgs("MGET", args_end=None), RedisCommandArgs("EXISTS", args_end=None), RedisCommandArgs("EXPIRETIME"), - RedisCommandArgs("JSON.CLEAR"), - RedisCommandArgs("JSON.SET"), - RedisCommandArgs("JSON.STRLEN"), - RedisCommandArgs("JSON.STRAPPEND"), - RedisCommandArgs("JSON.NUMINCRBY"), - RedisCommandArgs("JSON.GET"), - RedisCommandArgs("JSON.OBJKEYS"), - RedisCommandArgs("JSON.ARRAPPEND"), - RedisCommandArgs("JSON.ARRLEN"), RedisVariableCommandArgs("EVALSHA", key_index=1), ] } diff --git a/src/palace/manager/service/storage/s3.py b/src/palace/manager/service/storage/s3.py index dcfd91fca..d68941f21 100644 --- a/src/palace/manager/service/storage/s3.py +++ b/src/palace/manager/service/storage/s3.py @@ -5,7 +5,7 @@ from io import BytesIO from string import Formatter from types import TracebackType -from typing import TYPE_CHECKING, BinaryIO +from typing import IO, TYPE_CHECKING, BinaryIO from urllib.parse import quote from botocore.exceptions import BotoCoreError, ClientError @@ -212,7 +212,7 @@ def multipart_create(self, key: str, content_type: str | None = None) -> str: return upload["UploadId"] def multipart_upload( - self, key: str, upload_id: str, part_number: int, content: bytes + self, key: str, upload_id: str, part_number: int, content: bytes | IO[bytes] ) -> MultipartS3UploadPart: self.log.info(f"Uploading part {part_number} of {key} to {self.bucket}") result = self.client.upload_part( diff --git a/src/palace/manager/sqlalchemy/model/identifier.py b/src/palace/manager/sqlalchemy/model/identifier.py index 68eba02c1..ab530ce0a 100644 --- a/src/palace/manager/sqlalchemy/model/identifier.py +++ b/src/palace/manager/sqlalchemy/model/identifier.py @@ -23,7 +23,7 @@ func, ) from sqlalchemy.exc import 
MultipleResultsFound, NoResultFound
-from sqlalchemy.orm import Mapped, joinedload, relationship
+from sqlalchemy.orm import Mapped, joinedload, relationship, selectinload
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import select
 from sqlalchemy.sql.expression import and_, or_
@@ -1216,3 +1216,54 @@ class RecursiveEquivalencyCache(Base):
     is_parent = Column(Boolean, Computed("parent_identifier_id = identifier_id"))

     __table_args__ = (UniqueConstraint(parent_identifier_id, identifier_id),)
+
+    @staticmethod
+    def equivalent_identifiers(
+        session: Session, identifiers: set[Identifier], type: str | None = None
+    ) -> dict[Identifier, Identifier]:
+        """
+        Find all equivalent identifiers for the given Identifiers.
+
+        :param session: DB Session
+        :param identifiers: Set of Identifiers that we need equivalencies for
+        :param type: An optional type. If given, only equivalent identifiers
+            of this type will be returned.
+        :return: A dictionary mapping input identifiers to equivalent identifiers.
+        """
+
+        # Find identifiers that don't need to be looked up
+        results = (
+            {i: i for i in identifiers if i.type == type} if type is not None else {}
+        )
+        needs_lookup = {i.id: i for i in identifiers - results.keys()}
+        if not needs_lookup:
+            return results
+
+        query = (
+            select(RecursiveEquivalencyCache)
+            .join(
+                Identifier,
+                RecursiveEquivalencyCache.identifier_id == Identifier.id,
+            )
+            .where(
+                RecursiveEquivalencyCache.parent_identifier_id.in_(needs_lookup.keys()),
+            )
+            .order_by(
+                RecursiveEquivalencyCache.parent_identifier_id,
+                RecursiveEquivalencyCache.is_parent.desc(),
+                RecursiveEquivalencyCache.identifier_id.desc(),
+            )
+            .options(
+                selectinload(RecursiveEquivalencyCache.identifier),
+            )
+        )
+        if type is not None:
+            query = query.where(Identifier.type == type)
+
+        equivalents = session.execute(query).scalars().all()
+
+        for equivalent in equivalents:
+            parent_identifier = needs_lookup[equivalent.parent_identifier_id]
+            results[parent_identifier] = equivalent.identifier
+
+        return results
diff --git a/src/palace/manager/sqlalchemy/model/licensing.py b/src/palace/manager/sqlalchemy/model/licensing.py
index 2105e3e17..073d92f6f 100644
--- a/src/palace/manager/sqlalchemy/model/licensing.py
+++ b/src/palace/manager/sqlalchemy/model/licensing.py
@@ -298,6 +298,8 @@ class LicensePool(Base):
     # Identifier from a given DataSource.
     __table_args__ = (
         UniqueConstraint("identifier_id", "data_source_id", "collection_id"),
+        # This index was added to speed up queries for generating MARC records.
+ Index("ix_licensepools_collection_id_work_id", collection_id, work_id), ) delivery_mechanisms: Mapped[list[LicensePoolDeliveryMechanism]] = relationship( diff --git a/tests/fixtures/marc.py b/tests/fixtures/marc.py index 8d2bea910..f33977162 100644 --- a/tests/fixtures/marc.py +++ b/tests/fixtures/marc.py @@ -42,8 +42,6 @@ def __init__( self.collection2.libraries = [self.library1] self.collection3.libraries = [self.library2] - self.test_marc_file_key = "test-file-1.mrc" - def integration(self) -> IntegrationConfiguration: return self._db.integration_configuration( MarcExporter, Goals.CATALOG_GOAL, name="MARC Exporter" @@ -54,13 +52,17 @@ def work(self, collection: Collection | None = None) -> Work: edition = self._db.edition() self._db.licensepool(edition, collection=collection) work = self._db.work(presentation_edition=edition) - work.last_update_time = utc_now() + # We set the works last updated time to 1 day ago, so we know this work + # will only be included in delta exports covering a time range before + # 1 day ago. This lets us easily test works being included / excluded + # based on their `last_update_time`. + work.last_update_time = utc_now() - datetime.timedelta(days=1) return work def works(self, collection: Collection | None = None) -> list[Work]: return [self.work(collection) for _ in range(5)] - def configure_export(self, *, marc_file: bool = True) -> None: + def configure_export(self) -> None: marc_integration = self.integration() self._db.integration_library_configuration( marc_integration, @@ -77,12 +79,6 @@ def configure_export(self, *, marc_file: bool = True) -> None: self.collection2.export_marc_records = True self.collection3.export_marc_records = True - if marc_file: - self.marc_file( - key=self.test_marc_file_key, - created=utc_now() - datetime.timedelta(days=7), - ) - def enabled_libraries( self, collection: Collection | None = None ) -> Sequence[LibraryInfo]: diff --git a/tests/fixtures/s3.py b/tests/fixtures/s3.py index 3236b62b8..b25d2435b 100644 --- a/tests/fixtures/s3.py +++ b/tests/fixtures/s3.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Generator from dataclasses import dataclass, field -from typing import TYPE_CHECKING, BinaryIO, NamedTuple, Protocol +from typing import IO, TYPE_CHECKING, BinaryIO, NamedTuple, Protocol from unittest.mock import MagicMock from uuid import uuid4 @@ -73,10 +73,15 @@ def _upload_complete(self) -> None: def _upload_abort(self) -> None: ... 
-@dataclass class MockMultipartUploadPart: - part_data: MultipartS3UploadPart - content: bytes + def __init__( + self, part_data: MultipartS3UploadPart, content: bytes | IO[bytes] + ) -> None: + self.part_data = part_data + if isinstance(content, bytes): + self.content = content + else: + self.content = content.read() @dataclass @@ -131,7 +136,7 @@ def multipart_create(self, key: str, content_type: str | None = None) -> str: return upload_id def multipart_upload( - self, key: str, upload_id: str, part_number: int, content: bytes + self, key: str, upload_id: str, part_number: int, content: bytes | IO[bytes] ) -> MultipartS3UploadPart: etag = str(uuid4()) if not 1 <= part_number <= 10000: diff --git a/tests/manager/celery/tasks/test_marc.py b/tests/manager/celery/tasks/test_marc.py index 929a4c66a..798d5de3f 100644 --- a/tests/manager/celery/tasks/test_marc.py +++ b/tests/manager/celery/tasks/test_marc.py @@ -1,5 +1,4 @@ import datetime -from typing import Any from unittest.mock import ANY, call, patch import pytest @@ -7,14 +6,11 @@ from sqlalchemy import select from palace.manager.celery.tasks import marc +from palace.manager.celery.tasks.marc import marc_export_collection_lock from palace.manager.marc.exporter import MarcExporter from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.logging.configuration import LogLevel -from palace.manager.service.redis.models.marc import ( - MarcFileUploadSession, - MarcFileUploadSessionError, - MarcFileUploadState, -) +from palace.manager.service.redis.models.lock import RedisLock from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.marcfile import MarcFile from palace.manager.sqlalchemy.model.work import Work @@ -49,6 +45,11 @@ def test_normal_run( celery_fixture: CeleryFixture, ): marc_exporter_fixture.configure_export() + marc_exporter_fixture.marc_file( + collection=marc_exporter_fixture.collection1, + library=marc_exporter_fixture.library1, + created=utc_now() - datetime.timedelta(days=7), + ) with patch.object(marc, "marc_export_collection") as marc_export_collection: # Runs against all the expected collections collections = [ @@ -59,14 +60,43 @@ def test_normal_run( for collection in collections: marc_exporter_fixture.work(collection) marc.marc_export.delay().wait() + + # We make the calls to generate a full export for every collection marc_export_collection.delay.assert_has_calls( [ - call(collection_id=collection.id, start_time=ANY, libraries=ANY) + call( + collection_id=collection.id, + collection_name=collection.name, + start_time=ANY, + libraries=ANY, + ) for collection in collections ], any_order=True, ) + # We make the calls to generate a delta export only for collection1 + marc_export_collection.delay.assert_any_call( + collection_id=marc_exporter_fixture.collection1.id, + collection_name=marc_exporter_fixture.collection1.name, + start_time=ANY, + libraries=ANY, + delta=True, + ) + + # Make sure the call was made with the correct library set + [delta_call] = [ + c + for c in marc_export_collection.delay.mock_calls + if "delta" in c.kwargs + ] + libraries_kwarg = delta_call.kwargs["libraries"] + assert len(libraries_kwarg) == 1 + assert ( + libraries_kwarg[0].get("library_id") + == marc_exporter_fixture.library1.id + ) + def test_skip_collections( self, db: DatabaseTransactionFixture, @@ -75,36 +105,34 @@ def test_skip_collections( celery_fixture: CeleryFixture, ): marc_exporter_fixture.configure_export() - collections = [ - 
marc_exporter_fixture.collection1, - marc_exporter_fixture.collection2, - marc_exporter_fixture.collection3, - ] - for collection in collections: - marc_exporter_fixture.work(collection) with patch.object(marc, "marc_export_collection") as marc_export_collection: - # Collection 1 should be skipped because it is locked - assert marc_exporter_fixture.collection1.id is not None - MarcFileUploadSession( - redis_fixture.client, marc_exporter_fixture.collection1.id - ).acquire() + # Collection 1 should be skipped because it has no works # Collection 2 should be skipped because it was updated recently + marc_exporter_fixture.work(marc_exporter_fixture.collection2) marc_exporter_fixture.marc_file( - collection=marc_exporter_fixture.collection2 + collection=marc_exporter_fixture.collection2, + library=marc_exporter_fixture.library1, ) - # Collection 3 should be skipped because its state is not INITIAL - assert marc_exporter_fixture.collection3.id is not None - upload_session = MarcFileUploadSession( - redis_fixture.client, marc_exporter_fixture.collection3.id + # Collection 3 should get a full export, but not a delta, because + # its work hasn't been updated since the last full export + work = marc_exporter_fixture.work(marc_exporter_fixture.collection3) + work.last_update_time = utc_now() - datetime.timedelta(days=50) + marc_exporter_fixture.marc_file( + collection=marc_exporter_fixture.collection3, + library=marc_exporter_fixture.library2, + created=utc_now() - datetime.timedelta(days=45), ) - with upload_session.lock() as acquired: - assert acquired - upload_session.set_state(MarcFileUploadState.QUEUED) marc.marc_export.delay().wait() - marc_export_collection.delay.assert_not_called() + + marc_export_collection.delay.assert_called_once_with( + collection_id=marc_exporter_fixture.collection3.id, + collection_name=marc_exporter_fixture.collection3.name, + start_time=ANY, + libraries=ANY, + ) class MarcExportCollectionFixture: @@ -133,17 +161,7 @@ def __init__( self.start_time = utc_now() def marc_files(self) -> list[MarcFile]: - # We need to ignore the test-file-1.mrc file, which is created by our call to configure_export. 
- return [ - f - for f in self.db.session.execute(select(MarcFile)).scalars().all() - if f.key != self.marc_exporter_fixture.test_marc_file_key - ] - - def redis_data(self, collection: Collection) -> dict[str, Any] | None: - assert collection.id is not None - uploads = MarcFileUploadSession(self.redis_fixture.client, collection.id) - return self.redis_fixture.client.json().get(uploads.key) + return self.db.session.execute(select(MarcFile)).scalars().all() def setup_minio_storage(self) -> None: self.services_fixture.services.storage.override( @@ -156,15 +174,26 @@ def setup_mock_storage(self) -> None: def works(self, collection: Collection) -> list[Work]: return [self.marc_exporter_fixture.work(collection) for _ in range(15)] - def export_collection(self, collection: Collection) -> None: + def export_collection(self, collection: Collection, delta: bool = False) -> None: service = self.services_fixture.services.integration_registry.catalog_services() assert collection.id is not None info = MarcExporter.enabled_libraries(self.db.session, service, collection.id) libraries = [l.model_dump() for l in info] marc.marc_export_collection.delay( - collection.id, batch_size=5, start_time=self.start_time, libraries=libraries + collection.id, + collection_name=collection.name, + batch_size=5, + start_time=self.start_time, + libraries=libraries, + delta=delta, ).wait() + def redis_lock(self, collection: Collection, delta: bool = False) -> RedisLock: + assert collection.id is not None + return marc_export_collection_lock( + self.redis_fixture.client, collection.id, delta=delta + ) + @pytest.fixture def marc_export_collection_fixture( @@ -196,25 +225,23 @@ def test_normal_run( ): marc_export_collection_fixture.setup_minio_storage() collection = marc_exporter_fixture.collection1 - work_uris = [ - work.license_pools[0].identifier.urn - for work in marc_export_collection_fixture.works(collection) - ] + works = marc_export_collection_fixture.works(collection) + work_uris = [work.license_pools[0].identifier.urn for work in works] # Run the full end-to-end process for exporting a collection, this should generate # 3 batches of 5 works each, putting the results into minio. marc_export_collection_fixture.export_collection(collection) - # Verify that we didn't leave anything in the redis cache. - assert marc_export_collection_fixture.redis_data(collection) is None + # Lock is released + assert not marc_export_collection_fixture.redis_lock(collection).locked() # Verify that the expected number of files were uploaded to minio. uploaded_files = s3_service_integration_fixture.list_objects("public") - assert len(uploaded_files) == 3 + assert len(uploaded_files) == 2 # Verify that the expected number of marc files were created in the database. marc_files = marc_export_collection_fixture.marc_files() - assert len(marc_files) == 3 + assert len(marc_files) == 2 filenames = [marc_file.key for marc_file in marc_files] # Verify that the uploaded files are the expected ones. 
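For context on the record-status assertions that follow: MARC 21 stores the record status in Leader position 05, which pymarc exposes as record.leader.record_status, and the tests rely on "n" (new) marking records from a full export and "c" (corrected/revised) marking records from a delta export. Below is a minimal sketch of reading that byte back from an exported file; the local file name is hypothetical, and the key layout is the one shown later in test__s3_key.

from pymarc import MARCReader

# Keys look like marc/{library}/{collection}.delta.{since}.{until}.{uuid}.mrc;
# here we assume the object has already been downloaded to a local file.
with open("collection1.delta.mrc", "rb") as marc_file:
    for record in MARCReader(marc_file):
        if record is None:
            # MARCReader yields None for records it cannot parse.
            continue
        # Leader/05 is the record status: "n" = new, "c" = corrected/revised.
        print(record["001"].data, record.leader.record_status)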
@@ -233,10 +260,51 @@ def test_normal_run(
             assert all(record["003"].data == expected_org for record in records)
 
             # Make sure records have the correct status
-            expected_status = "c" if "delta" in file else "n"
-            assert all(
-                record.leader.record_status == expected_status for record in records
+            assert all(record.leader.record_status == "n" for record in records)
+
+        # Try running a delta export now
+        marc_export_collection_fixture.export_collection(collection, delta=True)
+
+        # Because no works have been updated since the last run, no delta exports are generated
+        marc_files = marc_export_collection_fixture.marc_files()
+        assert len(marc_files) == 2
+
+        # Update the last_update_time on a couple of works
+        updated_works = [works[0], works[1]]
+        for work in updated_works:
+            work.last_update_time = utc_now()
+
+        marc_export_collection_fixture.export_collection(collection, delta=True)
+
+        # Now the delta MARC files are generated
+        marc_files = marc_export_collection_fixture.marc_files()
+        assert len(marc_files) == 4
+        delta_marc_files = [
+            marc_file
+            for marc_file in marc_files
+            if marc_file.key and "delta" in marc_file.key
+        ]
+        assert len(delta_marc_files) == 2
+
+        # Verify that the delta MARC files contain the expected works.
+        for marc_file in delta_marc_files:
+            assert marc_file.key is not None
+            data = s3_service_integration_fixture.get_object("public", marc_file.key)
+            records = list(MARCReader(data))
+            assert len(records) == 2
+            marc_uris = [record["001"].data for record in records]
+            assert set(marc_uris) == {
+                work.license_pools[0].identifier.urn for work in updated_works
+            }
+
+            # Make sure the records have the correct organization code.
+            expected_org = (
+                "library1-org" if "library1" in marc_file.key else "library2-org"
             )
+            assert all(record["003"].data == expected_org for record in records)
+
+            # Make sure records have the correct status
+            assert all(record.leader.record_status == "c" for record in records)
 
     def test_collection_no_works(
         self,
@@ -250,7 +318,7 @@ def test_collection_no_works(
         assert marc_export_collection_fixture.marc_files() == []
         assert s3_service_integration_fixture.list_objects("public") == []
 
-        assert marc_export_collection_fixture.redis_data(collection) is None
+        assert not marc_export_collection_fixture.redis_lock(collection).locked()
 
     def test_exception_handled(
         self,
@@ -266,10 +334,10 @@ def test_exception_handled(
         with pytest.raises(Exception, match="Test Exception"):
             marc_export_collection_fixture.export_collection(collection)
 
-        # After the exception, we should have aborted the multipart uploads and deleted the redis data.
+ # After the exception, we should have aborted the multipart uploads and released the lock assert marc_export_collection_fixture.marc_files() == [] - assert marc_export_collection_fixture.redis_data(collection) is None - assert len(marc_export_collection_fixture.mock_s3.aborted) == 3 + assert len(marc_export_collection_fixture.mock_s3.aborted) == 2 + assert not marc_export_collection_fixture.redis_lock(collection).locked() def test_locked( self, @@ -280,41 +348,13 @@ def test_locked( ): caplog.set_level(LogLevel.info) collection = marc_exporter_fixture.collection1 - assert collection.id is not None - MarcFileUploadSession(redis_fixture.client, collection.id).acquire() + marc_export_collection_fixture.redis_lock(collection).acquire() marc_export_collection_fixture.setup_mock_storage() with patch.object(MarcExporter, "query_works") as query: marc_export_collection_fixture.export_collection(collection) query.assert_not_called() assert "another task is already processing it" in caplog.text - def test_outdated_task_run( - self, - redis_fixture: RedisFixture, - marc_exporter_fixture: MarcExporterFixture, - marc_export_collection_fixture: MarcExportCollectionFixture, - caplog: pytest.LogCaptureFixture, - ): - # In the case that an old task is run again for some reason, it should - # detect that its update number is incorrect and exit. - caplog.set_level(LogLevel.info) - collection = marc_exporter_fixture.collection1 - marc_export_collection_fixture.setup_mock_storage() - assert collection.id is not None - - # Acquire the lock and start an upload, this simulates another task having done work - # that the current task doesn't know about. - uploads = MarcFileUploadSession(redis_fixture.client, collection.id) - with uploads.lock() as locked: - assert locked - uploads.append_buffers({"test": "data"}) - - with pytest.raises(MarcFileUploadSessionError, match="Update number mismatch"): - marc_export_collection_fixture.export_collection(collection) - - assert marc_export_collection_fixture.marc_files() == [] - assert marc_export_collection_fixture.redis_data(collection) is None - def test_marc_export_cleanup( db: DatabaseTransactionFixture, @@ -323,7 +363,7 @@ def test_marc_export_cleanup( marc_exporter_fixture: MarcExporterFixture, services_fixture: ServicesFixture, ): - marc_exporter_fixture.configure_export(marc_file=False) + marc_exporter_fixture.configure_export() mock_s3 = s3_service_fixture.mock_service() services_fixture.services.storage.public.override(mock_s3) diff --git a/tests/manager/marc/test_annotator.py b/tests/manager/marc/test_annotator.py index 41d59c125..f55a137df 100644 --- a/tests/manager/marc/test_annotator.py +++ b/tests/manager/marc/test_annotator.py @@ -11,7 +11,6 @@ from palace.manager.marc.annotator import Annotator from palace.manager.sqlalchemy.model.classification import Genre from palace.manager.sqlalchemy.model.contributor import Contributor -from palace.manager.sqlalchemy.model.datasource import DataSource from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.model.identifier import Identifier from palace.manager.sqlalchemy.model.licensing import ( @@ -108,7 +107,7 @@ def test_marc_record( work, pool = annotator_fixture.test_work() annotator = annotator_fixture.annotator - record = annotator.marc_record(work, pool) + record = annotator.marc_record(work, pool.identifier, pool) assert annotator_fixture.record_tags(record) == { 1, 5, @@ -136,7 +135,7 @@ def test_marc_record( def test__copy_record(self, annotator_fixture: 
AnnotatorFixture): work, pool = annotator_fixture.test_work() annotator = annotator_fixture.annotator - record = annotator.marc_record(work, pool) + record = annotator.marc_record(work, None, pool) copied = annotator_fixture.annotator._copy_record(record) assert copied is not record assert copied.as_marc() == record.as_marc() @@ -144,7 +143,7 @@ def test__copy_record(self, annotator_fixture: AnnotatorFixture): def test_library_marc_record(self, annotator_fixture: AnnotatorFixture): work, pool = annotator_fixture.test_work() annotator = annotator_fixture.annotator - generic_record = annotator.marc_record(work, pool) + generic_record = annotator.marc_record(work, None, pool) library_marc_record = functools.partial( annotator.library_marc_record, @@ -156,6 +155,7 @@ def test_library_marc_record(self, annotator_fixture: AnnotatorFixture): organization_code="xyz", include_summary=True, include_genres=True, + delta=False, ) library_record = library_marc_record() @@ -200,6 +200,13 @@ def test_library_marc_record(self, annotator_fixture: AnnotatorFixture): library_record, includes={3, 520, 650}, excludes={856} ) + # If the record is part of a delta, then the flag is set + library_record = library_marc_record(delta=False) + assert library_record.leader.record_status == "n" + + library_record = library_marc_record(delta=True) + assert library_record.leader.record_status == "c" + def test_leader(self, annotator_fixture: AnnotatorFixture): leader = annotator_fixture.annotator.leader(False) assert leader == "00000nam 2200000 4500" @@ -273,19 +280,9 @@ def test_add_isbn( annotator_fixture.annotator.add_isbn(record, isbn) annotator_fixture.assert_field(record, "020", {"a": isbn.identifier}) - # If the identifier isn't an ISBN, but has an equivalent that is, it still - # works. - equivalent = db.identifier() - data_source = DataSource.lookup(db.session, DataSource.OCLC) - equivalent.equivalent_to(data_source, isbn, 1) - record = annotator_fixture.record() - annotator_fixture.annotator.add_isbn(record, equivalent) - annotator_fixture.assert_field(record, "020", {"a": isbn.identifier}) - # If there is no ISBN, the field is left out. 
- non_isbn = db.identifier() record = annotator_fixture.record() - annotator_fixture.annotator.add_isbn(record, non_isbn) + annotator_fixture.annotator.add_isbn(record, None) assert [] == record.get_fields("020") def test_add_title( diff --git a/tests/manager/marc/test_exporter.py b/tests/manager/marc/test_exporter.py index f5c01804a..b5534e9e4 100644 --- a/tests/manager/marc/test_exporter.py +++ b/tests/manager/marc/test_exporter.py @@ -1,45 +1,23 @@ import datetime from functools import partial -from unittest.mock import ANY, call, create_autospec -from uuid import UUID import pytest from freezegun import freeze_time +from sqlalchemy.exc import InvalidRequestError from palace.manager.marc.exporter import LibraryInfo, MarcExporter from palace.manager.marc.settings import MarcExporterLibrarySettings -from palace.manager.marc.uploader import MarcUploadManager from palace.manager.sqlalchemy.model.discovery_service_registration import ( DiscoveryServiceRegistration, ) from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.sqlalchemy.util import create, get_one +from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import datetime_utc, utc_now from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.marc import MarcExporterFixture class TestMarcExporter: - def test__s3_key(self, marc_exporter_fixture: MarcExporterFixture) -> None: - library = marc_exporter_fixture.library1 - collection = marc_exporter_fixture.collection1 - - uuid = UUID("c2370bf2-28e1-40ff-9f04-4864306bd11c") - now = datetime_utc(2024, 8, 27) - since = datetime_utc(2024, 8, 20) - - s3_key = partial(MarcExporter._s3_key, library, collection, now, uuid) - - assert ( - s3_key() - == f"marc/{library.short_name}/{collection.name}.full.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" - ) - - assert ( - s3_key(since_time=since) - == f"marc/{library.short_name}/{collection.name}.delta.2024-08-20.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" - ) - @freeze_time("2020-02-20T10:00:00Z") @pytest.mark.parametrize( "last_updated_time, update_frequency, expected", @@ -235,6 +213,7 @@ def test_enabled_libraries( collection_id=marc_exporter_fixture.collection1.id, ) + assert enabled_libraries(collection_id=None) == [] assert enabled_libraries() == [] # Collections have marc export enabled, and the marc exporter integration is setup, but @@ -266,8 +245,6 @@ def assert_library_2(library_info: LibraryInfo) -> None: assert library_info.include_summary is False assert library_info.include_genres is False assert library_info.web_client_urls == ("http://web-client",) - assert library_info.s3_key_full.startswith("marc/library2/collection1.full") - assert library_info.s3_key_delta is None assert_library_2(library_2_info) @@ -293,8 +270,6 @@ def assert_library_2(library_info: LibraryInfo) -> None: assert library_1_info.include_summary is True assert library_1_info.include_genres is True assert library_1_info.web_client_urls == () - assert library_1_info.s3_key_full.startswith("marc/library1/collection1.full") - assert library_1_info.s3_key_delta is None def test_query_works(self, marc_exporter_fixture: MarcExporterFixture) -> None: assert marc_exporter_fixture.collection1.id is not None @@ -306,11 +281,23 @@ def test_query_works(self, marc_exporter_fixture: MarcExporterFixture) -> None: batch_size=3, ) + assert query_works(collection_id=None) == [] assert query_works() == [] works = marc_exporter_fixture.works() - assert query_works() == works[:3] + result = query_works() + assert 
result == works[:3] + + # Make sure the loader options are correctly set on the results, this will cause an InvalidRequestError + # to be raised on any attribute access that doesn't have a loader setup. LicensePool.loans is an example + # of an unconfigured attribute. + with pytest.raises( + InvalidRequestError, + match="'LicensePool.loans' is not available due to lazy='raise'", + ): + _ = result[0].license_pools[0].loans + assert query_works(work_id_offset=works[3].id) == works[4:] def test_collection(self, marc_exporter_fixture: MarcExporterFixture) -> None: @@ -329,105 +316,38 @@ def test_collection(self, marc_exporter_fixture: MarcExporterFixture) -> None: def test_process_work(self, marc_exporter_fixture: MarcExporterFixture) -> None: marc_exporter_fixture.configure_export() + marc_exporter_fixture.marc_file( + library=marc_exporter_fixture.library1, + created=utc_now() - datetime.timedelta(days=14), + ) collection = marc_exporter_fixture.collection1 work = marc_exporter_fixture.work(collection) + pool = work.license_pools[0] enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) - mock_upload_manager = create_autospec(MarcUploadManager) - process_work = partial( MarcExporter.process_work, work, + pool, + None, enabled_libraries, "http://base.url", - upload_manager=mock_upload_manager, ) - process_work() - mock_upload_manager.add_record.assert_has_calls( - [ - call(enabled_libraries[0].s3_key_full, ANY), - call(enabled_libraries[0].s3_key_delta, ANY), - call(enabled_libraries[1].s3_key_full, ANY), - ] - ) - - # If the work has no license pools, it is skipped. - mock_upload_manager.reset_mock() - work.license_pools = [] - process_work() - mock_upload_manager.add_record.assert_not_called() - - def test_create_marc_upload_records( - self, marc_exporter_fixture: MarcExporterFixture - ) -> None: - marc_exporter_fixture.configure_export() - - collection = marc_exporter_fixture.collection1 - assert collection.id is not None - enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) - - marc_exporter_fixture.session.query(MarcFile).delete() + # We get both libraries included in a full record + processed_works = process_work(False) + assert list(processed_works.keys()) == enabled_libraries - start_time = utc_now() - - # If there are no uploads, then no records are created. - MarcExporter.create_marc_upload_records( - marc_exporter_fixture.session, - start_time, - collection.id, - enabled_libraries, - set(), - ) - - assert len(marc_exporter_fixture.session.query(MarcFile).all()) == 0 - - # If there are uploads, then records are created. 
- assert enabled_libraries[0].s3_key_delta is not None - MarcExporter.create_marc_upload_records( - marc_exporter_fixture.session, - start_time, - collection.id, - enabled_libraries, - { - enabled_libraries[0].s3_key_full, - enabled_libraries[1].s3_key_full, - enabled_libraries[0].s3_key_delta, - }, - ) - - assert len(marc_exporter_fixture.session.query(MarcFile).all()) == 3 - - assert get_one( - marc_exporter_fixture.session, - MarcFile, - collection=collection, - library_id=enabled_libraries[0].library_id, - key=enabled_libraries[0].s3_key_full, - ) - - assert get_one( - marc_exporter_fixture.session, - MarcFile, - collection=collection, - library_id=enabled_libraries[1].library_id, - key=enabled_libraries[1].s3_key_full, - ) - - assert get_one( - marc_exporter_fixture.session, - MarcFile, - collection=collection, - library_id=enabled_libraries[0].library_id, - key=enabled_libraries[0].s3_key_delta, - since=enabled_libraries[0].last_updated, - ) + # But we only get library1 in a delta record, since this is the first full marc export + # for library2, so there is no timestamp to create a delta record against. + [processed_work] = process_work(True).keys() + assert processed_work.library_id == marc_exporter_fixture.library1.id def test_files_for_cleanup_deleted_disabled( self, marc_exporter_fixture: MarcExporterFixture ) -> None: - marc_exporter_fixture.configure_export(marc_file=False) + marc_exporter_fixture.configure_export() files_for_cleanup = partial( MarcExporter.files_for_cleanup, marc_exporter_fixture.session, @@ -485,7 +405,7 @@ def test_files_for_cleanup_deleted_disabled( def test_files_for_cleanup_outdated_full( self, marc_exporter_fixture: MarcExporterFixture ) -> None: - marc_exporter_fixture.configure_export(marc_file=False) + marc_exporter_fixture.configure_export() files_for_cleanup = partial( MarcExporter.files_for_cleanup, marc_exporter_fixture.session, @@ -509,7 +429,7 @@ def test_files_for_cleanup_outdated_full( def test_files_for_cleanup_outdated_delta( self, marc_exporter_fixture: MarcExporterFixture ) -> None: - marc_exporter_fixture.configure_export(marc_file=False) + marc_exporter_fixture.configure_export() files_for_cleanup = partial( MarcExporter.files_for_cleanup, marc_exporter_fixture.session, diff --git a/tests/manager/marc/test_uploader.py b/tests/manager/marc/test_uploader.py index ebe983752..0bf1d83fc 100644 --- a/tests/manager/marc/test_uploader.py +++ b/tests/manager/marc/test_uploader.py @@ -1,391 +1,224 @@ -from unittest.mock import MagicMock, call +from functools import partial +from io import BytesIO +from tempfile import TemporaryFile +from unittest.mock import create_autospec +from uuid import UUID, uuid4 import pytest -from celery.exceptions import Ignore, Retry -from palace.manager.marc.uploader import MarcUploadManager -from palace.manager.service.redis.models.marc import ( - MarcFileUpload, - MarcFileUploadSession, +from palace.manager.marc.uploader import ( + MarcUploadException, + MarcUploadManager, + UploadContext, ) -from palace.manager.sqlalchemy.model.resource import Representation -from tests.fixtures.redis import RedisFixture -from tests.fixtures.s3 import S3ServiceFixture, S3ServiceIntegrationFixture +from palace.manager.util.datetime_helpers import datetime_utc +from tests.fixtures.s3 import S3ServiceFixture class MarcUploadManagerFixture: - def __init__( - self, redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture - ): - self._redis_fixture = redis_fixture + def __init__(self, s3_service_fixture: S3ServiceFixture): 
self._s3_service_fixture = s3_service_fixture - - self.test_key1 = "test.123" - self.test_record1 = b"test_record_123" - self.test_key2 = "test*456" - self.test_record2 = b"test_record_456" - self.test_key3 = "test--?789" - self.test_record3 = b"test_record_789" - self.mock_s3_service = s3_service_fixture.mock_service() + # Reduce the minimum upload size to make testing easier - self.mock_s3_service.MINIMUM_MULTIPART_UPLOAD_SIZE = len(self.test_record1) * 4 - self.redis_client = redis_fixture.client + self.mock_s3_service.MINIMUM_MULTIPART_UPLOAD_SIZE = 4 - self.mock_collection_id = 52 + self.collection_name = "collection" + self.library_short_name = "short_name" + self.creation_time = datetime_utc(year=2001, month=1, day=1) - self.uploads = MarcFileUploadSession(self.redis_client, self.mock_collection_id) - self.uploader = MarcUploadManager(self.mock_s3_service, self.uploads) + self.create_uploader = partial( + MarcUploadManager, + storage_service=self.mock_s3_service, + collection_name=self.collection_name, + library_short_name=self.library_short_name, + creation_time=self.creation_time, + since_time=None, + ) @pytest.fixture -def marc_upload_manager_fixture( - redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture -): - return MarcUploadManagerFixture(redis_fixture, s3_service_fixture) +def marc_upload_manager_fixture(s3_service_fixture: S3ServiceFixture): + return MarcUploadManagerFixture(s3_service_fixture) class TestMarcUploadManager: - def test_begin( - self, - marc_upload_manager_fixture: MarcUploadManagerFixture, - redis_fixture: RedisFixture, - ): - uploader = marc_upload_manager_fixture.uploader + def test__s3_key(self) -> None: + library_short_name = "short" + collection_name = "Palace is great" + uuid = UUID("c2370bf2-28e1-40ff-9f04-4864306bd11c") + now = datetime_utc(2024, 8, 27) + since = datetime_utc(2024, 8, 20) - assert uploader.locked is False - assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False - - with uploader.begin() as u: - # The context manager returns the uploader object - assert u is uploader - - # It directly tells us the lock status - assert uploader.locked is True - - # The lock is also reflected in the uploads object - assert marc_upload_manager_fixture.uploads.locked(by_us=True) is True - - # The lock is released after the context manager exits - assert uploader.locked is False - assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False - - # If an exception occurs, the lock is deleted and the exception is raised by calling - # the _abort method - mock_abort = MagicMock(wraps=uploader._abort) - uploader._abort = mock_abort - with pytest.raises(Exception): - with uploader.begin(): - assert uploader.locked is True - raise Exception() - assert ( - redis_fixture.client.json().get(marc_upload_manager_fixture.uploads.key) - is None - ) - mock_abort.assert_called_once() - - # If a expected celery exception occurs, the lock is released, but not deleted - # and the abort method isn't called - mock_abort.reset_mock() - for exception in Retry, Ignore: - with pytest.raises(exception): - with uploader.begin(): - assert uploader.locked is True - raise exception() - assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False - assert ( - redis_fixture.client.json().get(marc_upload_manager_fixture.uploads.key) - is not None - ) - mock_abort.assert_not_called() - - def test_add_record(self, marc_upload_manager_fixture: MarcUploadManagerFixture): - uploader = marc_upload_manager_fixture.uploader - - uploader.add_record( - 
marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1, + s3_key = partial( + MarcUploadManager._s3_key, library_short_name, collection_name, now, uuid ) + collection_name_no_spaces = collection_name.replace(" ", "_") + assert ( - uploader._buffers[marc_upload_manager_fixture.test_key1] - == marc_upload_manager_fixture.test_record1.decode() + s3_key() + == f"marc/{library_short_name}/{collection_name_no_spaces}.full.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" ) - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1, - ) assert ( - uploader._buffers[marc_upload_manager_fixture.test_key1] - == marc_upload_manager_fixture.test_record1.decode() * 2 + s3_key(since_time=since) + == f"marc/{library_short_name}/{collection_name_no_spaces}.delta.2024-08-20.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" ) - def test_sync(self, marc_upload_manager_fixture: MarcUploadManagerFixture): - uploader = marc_upload_manager_fixture.uploader - - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1, - ) - uploader.add_record( - marc_upload_manager_fixture.test_key2, - marc_upload_manager_fixture.test_record2 * 2, - ) - with uploader.begin(): - uploader.sync() - - # Sync clears the local buffer - assert uploader._buffers == {} - - # And pushes the local records to redis - assert marc_upload_manager_fixture.uploads.get() == { - marc_upload_manager_fixture.test_key1: MarcFileUpload( - buffer=marc_upload_manager_fixture.test_record1 - ), - marc_upload_manager_fixture.test_key2: MarcFileUpload( - buffer=marc_upload_manager_fixture.test_record2 * 2 - ), - } - - # Because the buffer did not contain enough data, it was not uploaded to S3 - assert marc_upload_manager_fixture.mock_s3_service.upload_in_progress == {} - - # Add enough data for test_key1 to be uploaded to S3 - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1 * 2, - ) - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1 * 2, - ) - uploader.add_record( - marc_upload_manager_fixture.test_key2, - marc_upload_manager_fixture.test_record2, + def test__init_(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + # If we initialize with a pre-existing context, the context is set directly + context = UploadContext( + upload_uuid=uuid4(), s3_key="s3_key", upload_id="upload_id" ) + uploader = marc_upload_manager_fixture.create_uploader(context=context) + assert uploader.context is context - with uploader.begin(): - uploader.sync() - - # The buffer is cleared - assert uploader._buffers == {} - - # Because the data for test_key1 was large enough, it was uploaded to S3, and its redis data structure was - # updated to reflect this. test_key2 was not large enough to upload, so it remains in redis and not in s3. - redis_data = marc_upload_manager_fixture.uploads.get() - assert redis_data[marc_upload_manager_fixture.test_key2] == MarcFileUpload( - buffer=marc_upload_manager_fixture.test_record2 * 3 + # If we don't give a context, one is created and set + uploader = marc_upload_manager_fixture.create_uploader() + assert uploader.context.upload_id is None + assert uploader.context.s3_key.startswith( + f"marc/{marc_upload_manager_fixture.library_short_name}/" + f"{marc_upload_manager_fixture.collection_name}.full.2001-01-01." 
) - redis_data_test1 = redis_data[marc_upload_manager_fixture.test_key1] - assert redis_data_test1.buffer == "" + assert isinstance(uploader.context.upload_uuid, UUID) + assert uploader.context.parts == [] + def test_begin_upload(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.create_uploader() + assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 0 + uploader.begin_upload() assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 1 - assert ( - marc_upload_manager_fixture.test_key1 - in marc_upload_manager_fixture.mock_s3_service.upload_in_progress + [upload] = ( + marc_upload_manager_fixture.mock_s3_service.upload_in_progress.values() ) - upload = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ - marc_upload_manager_fixture.test_key1 - ] - assert upload.upload_id is not None - assert upload.content_type is Representation.MARC_MEDIA_TYPE - [part] = upload.parts.values() - assert part.content == marc_upload_manager_fixture.test_record1 * 5 - - # And the s3 part data and upload_id is synced to redis - assert redis_data_test1.parts == [part.part_data] - assert redis_data_test1.upload_id == upload.upload_id + assert uploader.context.upload_id == upload.upload_id - def test_complete(self, marc_upload_manager_fixture: MarcUploadManagerFixture): - uploader = marc_upload_manager_fixture.uploader + def test_upload_part(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.create_uploader() - # Wrap the clear method so we can check if it was called - mock_clear_uploads = MagicMock( - wraps=marc_upload_manager_fixture.uploads.clear_uploads - ) - marc_upload_manager_fixture.uploads.clear_uploads = mock_clear_uploads + # If begin_upload hasn't been called, it will be called by upload_part + assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 0 - # Set up the records for the test - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1 * 5, - ) - uploader.add_record( - marc_upload_manager_fixture.test_key2, - marc_upload_manager_fixture.test_record2 * 5, - ) - with uploader.begin(): - uploader.sync() + # Can upload parts as a binary file, or a byte string + assert uploader.upload_part(b"test") + with TemporaryFile() as f: + f.write(b" another test") + assert uploader.upload_part(f) - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1 * 5, - ) - with uploader.begin(): - uploader.sync() + # Empty parts are ignored + assert not uploader.upload_part(b"") + assert not uploader.upload_part(BytesIO()) - uploader.add_record( - marc_upload_manager_fixture.test_key2, - marc_upload_manager_fixture.test_record2, - ) + assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 1 - uploader.add_record( - marc_upload_manager_fixture.test_key3, - marc_upload_manager_fixture.test_record3, + [upload_parts] = ( + marc_upload_manager_fixture.mock_s3_service.upload_in_progress.values() ) + assert len(upload_parts.parts) == 2 - # Complete the uploads - with uploader.begin(): - completed = uploader.complete() - - # The complete method should return the keys that were completed - assert completed == { - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_key2, - marc_upload_manager_fixture.test_key3, - } - - # The local buffers should be empty - assert uploader._buffers == {} + # Complete the 
upload + assert uploader.complete() + [complete_upload] = marc_upload_manager_fixture.mock_s3_service.uploads.values() + assert complete_upload.content == b"test another test" - # The redis record should have the completed uploads cleared - mock_clear_uploads.assert_called_once() + # Trying to add a part to a complete upload raises an error + with pytest.raises(MarcUploadException, match="Upload is already finalized."): + uploader.upload_part(b"123") - # The s3 service should have the completed uploads - assert len(marc_upload_manager_fixture.mock_s3_service.uploads) == 3 + def test_abort(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + # If an upload hasn't been started abort just sets finalized + uploader = marc_upload_manager_fixture.create_uploader() + uploader.abort() + assert uploader.finalized assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 0 + assert len(marc_upload_manager_fixture.mock_s3_service.aborted) == 0 - test_key1_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ - marc_upload_manager_fixture.test_key1 - ] - assert test_key1_upload.key == marc_upload_manager_fixture.test_key1 - assert test_key1_upload.content == marc_upload_manager_fixture.test_record1 * 10 - assert test_key1_upload.media_type == Representation.MARC_MEDIA_TYPE - - test_key2_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ - marc_upload_manager_fixture.test_key2 - ] - assert test_key2_upload.key == marc_upload_manager_fixture.test_key2 - assert test_key2_upload.content == marc_upload_manager_fixture.test_record2 * 6 - assert test_key2_upload.media_type == Representation.MARC_MEDIA_TYPE - - test_key3_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ - marc_upload_manager_fixture.test_key3 - ] - assert test_key3_upload.key == marc_upload_manager_fixture.test_key3 - assert test_key3_upload.content == marc_upload_manager_fixture.test_record3 - assert test_key3_upload.media_type == Representation.MARC_MEDIA_TYPE - - def test__abort( - self, - marc_upload_manager_fixture: MarcUploadManagerFixture, - caplog: pytest.LogCaptureFixture, - ): - uploader = marc_upload_manager_fixture.uploader - - # Set up the records for the test - uploader.add_record( - marc_upload_manager_fixture.test_key1, - marc_upload_manager_fixture.test_record1 * 10, - ) - uploader.add_record( - marc_upload_manager_fixture.test_key2, - marc_upload_manager_fixture.test_record2 * 10, - ) - with uploader.begin(): - uploader.sync() - - # Mock the multipart_abort method so we can check if it was called and have it - # raise an exception on the first call - mock_abort = MagicMock(side_effect=[Exception("Boom"), None]) - marc_upload_manager_fixture.mock_s3_service.multipart_abort = mock_abort - - # Wrap the delete method so we can check if it was called - mock_delete = MagicMock(wraps=marc_upload_manager_fixture.uploads.delete) - marc_upload_manager_fixture.uploads.delete = mock_delete - - upload_id_1 = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ - marc_upload_manager_fixture.test_key1 - ].upload_id - upload_id_2 = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ - marc_upload_manager_fixture.test_key2 - ].upload_id - - # Abort the uploads, the original exception should propagate, and the exception - # thrown by the first call to abort should be logged - with pytest.raises(Exception) as exc_info: - with uploader.begin(): - raise Exception("Bang") - assert str(exc_info.value) == "Bang" - + # Otherwise abort calls to the API to abort the 
upload + uploader = marc_upload_manager_fixture.create_uploader() + uploader.begin_upload() + uploader.abort() + assert uploader.finalized assert ( - f"Failed to abort upload {marc_upload_manager_fixture.test_key1} (UploadID: {upload_id_1}) due to exception (Boom)" - in caplog.text - ) - - mock_abort.assert_has_calls( - [ - call(marc_upload_manager_fixture.test_key1, upload_id_1), - call(marc_upload_manager_fixture.test_key2, upload_id_2), - ] + uploader.context.s3_key + in marc_upload_manager_fixture.mock_s3_service.aborted ) - # The redis record should have been deleted - mock_delete.assert_called_once() - - def test_real_storage_service( - self, - redis_fixture: RedisFixture, - s3_service_integration_fixture: S3ServiceIntegrationFixture, - ): - """ - Full end-to-end test of the MarcUploadManager using the real S3Service - """ - s3_service = s3_service_integration_fixture.public - uploads = MarcFileUploadSession(redis_fixture.client, 99) - uploader = MarcUploadManager(s3_service, uploads) - batch_size = s3_service.MINIMUM_MULTIPART_UPLOAD_SIZE + 1 - - with uploader.begin() as locked: - assert locked - - # Test three buffer size cases for the complete() method. - # - # 1. A small record that isn't in S3 at the time `complete` is called (test1). - # 2. A large record that needs to be uploaded in parts. On the first `sync` - # call, its buffer is large enough to trigger an upload. When `complete` is - # called, the buffer has data waiting for upload (test2). - # 3. A large record that needs to be uploaded in parts. On the first `sync` - # call, its buffer is large enough to trigger the upload. When `complete` - # is called, the buffer is empty (test3). - - uploader.add_record("test1", b"test_record") - uploader.add_record("test2", b"a" * batch_size) - uploader.add_record("test3", b"b" * batch_size) - - # Start the sync. This will begin the multipart upload for test2 and test3. 
-            uploader.sync()
-
-            # Add some more data
-            uploader.add_record("test1", b"test_record")
-            uploader.add_record("test2", b"a" * batch_size)
-
-            # Complete the uploads
-            completed = uploader.complete()
-
-        assert completed == {"test1", "test2", "test3"}
-        assert uploads.get() == {}
-        assert set(s3_service_integration_fixture.list_objects("public")) == completed
+        # Calling abort again is a no-op
+        uploader.abort()
 
+    def test_complete(self, marc_upload_manager_fixture: MarcUploadManagerFixture):
+        # If the upload hasn't started, the upload isn't aborted, but it is finalized
+        uploader = marc_upload_manager_fixture.create_uploader()
+        assert not uploader.complete()
+        assert uploader.finalized
+
+        # If the upload has no parts, it is aborted
+        uploader = marc_upload_manager_fixture.create_uploader()
+        uploader.begin_upload()
+        assert not uploader.complete()
+        assert uploader.finalized
         assert (
-            s3_service_integration_fixture.get_object("public", "test1")
-            == b"test_record" * 2
+            uploader.context.s3_key
+            in marc_upload_manager_fixture.mock_s3_service.aborted
         )
+
+        # An upload with parts is completed
+        uploader = marc_upload_manager_fixture.create_uploader()
+        uploader.upload_part(b"test data")
+        assert uploader.complete()
+        assert uploader.finalized
         assert (
-            s3_service_integration_fixture.get_object("public", "test2")
-            == b"a" * batch_size * 2
+            uploader.context.s3_key
+            in marc_upload_manager_fixture.mock_s3_service.uploads
         )
         assert (
-            s3_service_integration_fixture.get_object("public", "test3")
-            == b"b" * batch_size
+            marc_upload_manager_fixture.mock_s3_service.uploads[
+                uploader.context.s3_key
+            ].content
+            == b"test data"
         )
+
+        # Calling complete a second time raises an exception
+        with pytest.raises(MarcUploadException, match="Upload is already finalized."):
+            uploader.complete()
+
+    def test_context_manager(
+        self,
+        marc_upload_manager_fixture: MarcUploadManagerFixture,
+        caplog: pytest.LogCaptureFixture,
+    ):
+        # Nesting the context manager raises an exception
+        uploader = marc_upload_manager_fixture.create_uploader()
+        with uploader:
+            with pytest.raises(
+                MarcUploadException, match="Cannot nest MarcUploadManager"
+            ):
+                with uploader:
+                    ...
+
+        # The context manager doesn't complete an in-progress upload
+        with marc_upload_manager_fixture.create_uploader() as uploader:
+            uploader.upload_part(b"test data")
+        assert not uploader.finalized
+        assert len(marc_upload_manager_fixture.mock_s3_service.uploads) == 0
+
+        # But if there is an exception, it cleans up the upload
+        with pytest.raises(Exception, match="Boom!"):
+            with uploader:
+                raise Exception("Boom!")
+        assert uploader.finalized
+        assert len(marc_upload_manager_fixture.mock_s3_service.uploads) == 0
+        assert len(marc_upload_manager_fixture.mock_s3_service.aborted) == 1
+
+        # If aborting the upload raises its own exception, we just swallow and log
+        # it, since we are already handling the outer exception.
+        uploader = marc_upload_manager_fixture.create_uploader()
+        uploader.abort = create_autospec(
+            uploader.abort, side_effect=Exception("Another exception")
+        )
+        caplog.clear()
+        with pytest.raises(Exception, match="Boom!"):
+            with uploader:
+                raise Exception("Boom!")
+        assert "Failed to abort upload" in caplog.text
+        assert "due to exception (Another exception)." 
in caplog.text diff --git a/tests/manager/service/redis/models/test_lock.py b/tests/manager/service/redis/models/test_lock.py index c317db7a8..ca7aa956d 100644 --- a/tests/manager/service/redis/models/test_lock.py +++ b/tests/manager/service/redis/models/test_lock.py @@ -1,17 +1,10 @@ from datetime import timedelta -from typing import Any from unittest.mock import create_autospec import pytest from palace.manager.celery.task import Task -from palace.manager.service.redis.models.lock import ( - LockError, - RedisJsonLock, - RedisLock, - TaskLock, -) -from palace.manager.service.redis.redis import Redis +from palace.manager.service.redis.models.lock import LockError, RedisLock, TaskLock from tests.fixtures.redis import RedisFixture @@ -189,147 +182,3 @@ def test___init__(self, redis_fixture: RedisFixture): # If we provide a lock_name, we should use that instead task_lock = TaskLock(redis_fixture.client, mock_task, lock_name="test_lock") assert task_lock.key.endswith("::TaskLock::test_lock") - - -class MockJsonLock(RedisJsonLock): - def __init__( - self, - redis_client: Redis, - key: str = "test", - timeout: int = 1000, - random_value: str | None = None, - ): - self._key = redis_client.get_key(key) - self._timeout = timeout - super().__init__(redis_client, random_value) - - @property - def key(self) -> str: - return self._key - - @property - def _lock_timeout_ms(self) -> int: - return self._timeout - - -class JsonLockFixture: - def __init__(self, redis_fixture: RedisFixture) -> None: - self.client = redis_fixture.client - self.lock = MockJsonLock(redis_fixture.client) - self.other_lock = MockJsonLock(redis_fixture.client) - - def get_key(self, key: str, json_key: str) -> Any: - ret_val = self.client.json().get(key, json_key) - if ret_val is None or len(ret_val) != 1: - return None - return ret_val[0] - - def assert_locked(self, lock: RedisJsonLock) -> None: - assert self.get_key(lock.key, lock._lock_json_key) == lock._random_value - - -@pytest.fixture -def json_lock_fixture(redis_fixture: RedisFixture) -> JsonLockFixture: - return JsonLockFixture(redis_fixture) - - -class TestJsonLock: - def test_acquire(self, json_lock_fixture: JsonLockFixture): - # We can acquire the lock. And acquiring the lock sets a timeout on the key, so the lock - # will expire eventually if something goes wrong. 
- assert json_lock_fixture.lock.acquire() - assert json_lock_fixture.client.ttl(json_lock_fixture.lock.key) > 0 - json_lock_fixture.assert_locked(json_lock_fixture.lock) - - # Acquiring the lock again with the same random value should return True - # and extend the timeout for the lock - json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) - timeout = json_lock_fixture.client.pttl(json_lock_fixture.lock.key) - assert json_lock_fixture.lock.acquire() - assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) > timeout - - # Acquiring the lock again with a different random value should return False - assert not json_lock_fixture.other_lock.acquire() - json_lock_fixture.assert_locked(json_lock_fixture.lock) - - def test_release(self, json_lock_fixture: JsonLockFixture): - # If the lock doesn't exist, we can't release it - assert json_lock_fixture.lock.release() is False - - # If you acquire a lock another client cannot release it - assert json_lock_fixture.lock.acquire() - assert json_lock_fixture.other_lock.release() is False - - # Make sure the key is set in redis - json_lock_fixture.assert_locked(json_lock_fixture.lock) - - # But the client that acquired the lock can release it - assert json_lock_fixture.lock.release() is True - - # And the key should still exist, but the lock key in the json is removed from redis - assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") == {} - - def test_delete(self, json_lock_fixture: JsonLockFixture): - assert json_lock_fixture.lock.delete() is False - - # If you acquire a lock another client cannot delete it - assert json_lock_fixture.lock.acquire() - assert json_lock_fixture.other_lock.delete() is False - - # Make sure the key is set in redis - assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is not None - json_lock_fixture.assert_locked(json_lock_fixture.lock) - - # But the client that acquired the lock can delete it - assert json_lock_fixture.lock.delete() is True - - # And the key should still exist, but the lock key in the json is removed from redis - assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is None - - def test_extend_timeout(self, json_lock_fixture: JsonLockFixture): - assert json_lock_fixture.lock.extend_timeout() is False - - # If the lock has a timeout, the acquiring client can extend it, but another client cannot - assert json_lock_fixture.lock.acquire() - json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) - assert json_lock_fixture.other_lock.extend_timeout() is False - assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) <= 500 - - # The key should have a new timeout - assert json_lock_fixture.lock.extend_timeout() is True - assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) > 500 - - def test_locked(self, json_lock_fixture: JsonLockFixture): - # If the lock is not acquired, it should not be locked - assert json_lock_fixture.lock.locked() is False - - # If the lock is acquired, it should be locked - assert json_lock_fixture.lock.acquire() - assert json_lock_fixture.lock.locked() is True - assert json_lock_fixture.other_lock.locked() is True - assert json_lock_fixture.lock.locked(by_us=True) is True - assert json_lock_fixture.other_lock.locked(by_us=True) is False - - # If the lock is released, it should not be locked - assert json_lock_fixture.lock.release() is True - assert json_lock_fixture.lock.locked() is False - assert json_lock_fixture.other_lock.locked() is False - - def test__parse_value(self): - assert 
RedisJsonLock._parse_value(None) is None - assert RedisJsonLock._parse_value([]) is None - assert RedisJsonLock._parse_value(["value"]) == "value" - - def test__parse_multi(self): - assert RedisJsonLock._parse_multi(None) == {} - assert RedisJsonLock._parse_multi({}) == {} - assert RedisJsonLock._parse_multi( - {"key": ["value"], "key2": ["value2"], "key3": []} - ) == {"key": "value", "key2": "value2", "key3": None} - - def test__parse_value_or_raise(self): - with pytest.raises(LockError): - RedisJsonLock._parse_value_or_raise(None) - with pytest.raises(LockError): - RedisJsonLock._parse_value_or_raise([]) - assert RedisJsonLock._parse_value_or_raise(["value"]) == "value" diff --git a/tests/manager/service/redis/models/test_marc.py b/tests/manager/service/redis/models/test_marc.py deleted file mode 100644 index b09e1387c..000000000 --- a/tests/manager/service/redis/models/test_marc.py +++ /dev/null @@ -1,476 +0,0 @@ -import string - -import pytest - -from palace.manager.service.redis.models.marc import ( - MarcFileUpload, - MarcFileUploadSession, - MarcFileUploadSessionError, - MarcFileUploadState, -) -from palace.manager.service.redis.redis import Pipeline -from palace.manager.service.storage.s3 import MultipartS3UploadPart -from tests.fixtures.redis import RedisFixture - - -class MarcFileUploadSessionFixture: - def __init__(self, redis_fixture: RedisFixture): - self._redis_fixture = redis_fixture - - self.mock_collection_id = 1 - - self.uploads = MarcFileUploadSession( - self._redis_fixture.client, self.mock_collection_id - ) - - # Some keys with special characters to make sure they are handled correctly. - self.mock_upload_key_1 = "test/test1/?$xyz.abc" - self.mock_upload_key_2 = "t'est💣/tëst2.\"ext`" - self.mock_upload_key_3 = string.printable - - self.mock_unset_upload_key = "test4" - - self.test_data = { - self.mock_upload_key_1: "test", - self.mock_upload_key_2: "another_test", - self.mock_upload_key_3: "another_another_test", - } - - self.part_1 = MultipartS3UploadPart(etag="abc", part_number=1) - self.part_2 = MultipartS3UploadPart(etag="def", part_number=2) - - def load_test_data(self) -> dict[str, int]: - lock_acquired = False - if not self.uploads.locked(): - self.uploads.acquire() - lock_acquired = True - - return_value = self.uploads.append_buffers(self.test_data) - - if lock_acquired: - self.uploads.release() - - return return_value - - def test_data_records(self, *keys: str) -> dict[str, MarcFileUpload]: - return {key: MarcFileUpload(buffer=self.test_data[key]) for key in keys} - - -@pytest.fixture -def marc_file_upload_session_fixture(redis_fixture: RedisFixture): - return MarcFileUploadSessionFixture(redis_fixture) - - -class TestMarcFileUploadSession: - def test__pipeline( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # Using the _pipeline() context manager makes sure that we hold the lock - with pytest.raises(MarcFileUploadSessionError) as exc_info: - with uploads._pipeline(): - pass - assert "Must hold lock" in str(exc_info.value) - - uploads.acquire() - - # It also checks that the update_number is correct - uploads._update_number = 1 - with pytest.raises(MarcFileUploadSessionError) as exc_info: - with uploads._pipeline(): - pass - assert "Update number mismatch" in str(exc_info.value) - - uploads._update_number = 0 - with uploads._pipeline() as pipe: - # If the lock and update number are correct, we should get a pipeline object - assert isinstance(pipe, Pipeline) - - # We are 
watching the key for this object, so that we know all the data within the - # transaction is consistent, and we are still holding the lock when the pipeline - # executes - assert pipe.watching is True - - # By default it starts the pipeline transaction - assert pipe.explicit_transaction is True - - # We can also start the pipeline without a transaction - with uploads._pipeline(begin_transaction=False) as pipe: - assert pipe.explicit_transaction is False - - def test__execute_pipeline( - self, - marc_file_upload_session_fixture: MarcFileUploadSessionFixture, - redis_fixture: RedisFixture, - ): - client = redis_fixture.client - uploads = marc_file_upload_session_fixture.uploads - uploads.acquire() - - # If we try to execute a pipeline without a transaction, we should get an error - with pytest.raises(MarcFileUploadSessionError) as exc_info: - with uploads._pipeline(begin_transaction=False) as pipe: - uploads._execute_pipeline(pipe, 0) - assert "Pipeline should be in explicit transaction mode" in str(exc_info.value) - - # The _execute_pipeline function takes care of extending the timeout and incrementing - # the update number and setting the state of the session - [update_number] = client.json().get( - uploads.key, uploads._update_number_json_key - ) - client.pexpire(uploads.key, 500) - old_state = uploads.state() - with uploads._pipeline() as pipe: - # If we execute the pipeline, we should get a list of results, excluding the - # operations that _execute_pipeline does. - assert uploads._execute_pipeline(pipe, 2) == [] - [new_update_number] = client.json().get( - uploads.key, uploads._update_number_json_key - ) - assert new_update_number == update_number + 2 - assert client.pttl(uploads.key) > 500 - assert uploads.state() != old_state - assert uploads.state() == MarcFileUploadState.UPLOADING - - # If we try to execute a pipeline that has been modified by another process, we should get an error - with uploads._pipeline() as pipe: - client.json().set( - uploads.key, uploads._update_number_json_key, update_number - ) - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads._execute_pipeline(pipe, 1) - assert "Another process is modifying the buffers" in str(exc_info.value) - - def test_append_buffers( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # If we try to update buffers without acquiring the lock, we should get an error - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.append_buffers( - {marc_file_upload_session_fixture.mock_upload_key_1: "test"} - ) - assert "Must hold lock" in str(exc_info.value) - - # Acquire the lock and try to update buffers - with uploads.lock() as locked: - assert locked - assert uploads.append_buffers({}) == {} - - assert uploads.append_buffers( - { - marc_file_upload_session_fixture.mock_upload_key_1: "test", - marc_file_upload_session_fixture.mock_upload_key_2: "another_test", - } - ) == { - marc_file_upload_session_fixture.mock_upload_key_1: 4, - marc_file_upload_session_fixture.mock_upload_key_2: 12, - } - assert uploads._update_number == 2 - - assert uploads.append_buffers( - { - marc_file_upload_session_fixture.mock_upload_key_1: "x", - marc_file_upload_session_fixture.mock_upload_key_2: "y", - marc_file_upload_session_fixture.mock_upload_key_3: "new", - } - ) == { - marc_file_upload_session_fixture.mock_upload_key_1: 5, - marc_file_upload_session_fixture.mock_upload_key_2: 13, - marc_file_upload_session_fixture.mock_upload_key_3: 3, - } 
- assert uploads._update_number == 5 - - # If we try to update buffers with an old update number, we should get an error - uploads._update_number = 4 - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.append_buffers(marc_file_upload_session_fixture.test_data) - assert "Update number mismatch" in str(exc_info.value) - - # Exiting the context manager should release the lock - assert not uploads.locked() - - def test_get(self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture): - uploads = marc_file_upload_session_fixture.uploads - - assert uploads.get() == {} - assert uploads.get(marc_file_upload_session_fixture.mock_upload_key_1) == {} - - marc_file_upload_session_fixture.load_test_data() - - # You don't need to acquire the lock to get the uploads, but you should if you - # are using the data to do updates. - - # You can get a subset of the uploads - assert uploads.get( - marc_file_upload_session_fixture.mock_upload_key_1, - ) == marc_file_upload_session_fixture.test_data_records( - marc_file_upload_session_fixture.mock_upload_key_1 - ) - - # Or multiple uploads, any that don't exist are not included in the result dict - assert uploads.get( - [ - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.mock_upload_key_2, - marc_file_upload_session_fixture.mock_unset_upload_key, - ] - ) == marc_file_upload_session_fixture.test_data_records( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.mock_upload_key_2, - ) - - # Or you can get all the uploads - assert uploads.get() == marc_file_upload_session_fixture.test_data_records( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.mock_upload_key_2, - marc_file_upload_session_fixture.mock_upload_key_3, - ) - - def test_set_upload_id( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # must hold lock to do update - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.set_upload_id( - marc_file_upload_session_fixture.mock_upload_key_1, "xyz" - ) - assert "Must hold lock" in str(exc_info.value) - - uploads.acquire() - - # We are unable to set an upload id for an item that hasn't been initialized - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.set_upload_id( - marc_file_upload_session_fixture.mock_upload_key_1, "xyz" - ) - assert "Failed to set upload ID" in str(exc_info.value) - - marc_file_upload_session_fixture.load_test_data() - uploads.set_upload_id(marc_file_upload_session_fixture.mock_upload_key_1, "def") - uploads.set_upload_id(marc_file_upload_session_fixture.mock_upload_key_2, "abc") - - all_uploads = uploads.get() - assert ( - all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].upload_id - == "def" - ) - assert ( - all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].upload_id - == "abc" - ) - - # We can't change the upload id for a library that has already been set - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.set_upload_id( - marc_file_upload_session_fixture.mock_upload_key_1, "ghi" - ) - assert "Failed to set upload ID" in str(exc_info.value) - - all_uploads = uploads.get() - assert ( - all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].upload_id - == "def" - ) - assert ( - all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].upload_id - == "abc" - ) - - def test_clear_uploads( - self, 
marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # must hold lock to do update - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.clear_uploads() - assert "Must hold lock" in str(exc_info.value) - - uploads.acquire() - - # We are unable to clear the uploads for an item that hasn't been initialized - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.clear_uploads() - assert "Failed to clear uploads" in str(exc_info.value) - - marc_file_upload_session_fixture.load_test_data() - assert uploads.get() != {} - - uploads.clear_uploads() - assert uploads.get() == {} - - def test_get_upload_ids( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # If the id is not set, we should get None - assert uploads.get_upload_ids( - [marc_file_upload_session_fixture.mock_upload_key_1] - ) == {marc_file_upload_session_fixture.mock_upload_key_1: None} - - marc_file_upload_session_fixture.load_test_data() - - # If the buffer has been set, but the upload id has not, we should still get None - assert uploads.get_upload_ids( - [marc_file_upload_session_fixture.mock_upload_key_1] - ) == {marc_file_upload_session_fixture.mock_upload_key_1: None} - - with uploads.lock() as locked: - assert locked - uploads.set_upload_id( - marc_file_upload_session_fixture.mock_upload_key_1, "abc" - ) - uploads.set_upload_id( - marc_file_upload_session_fixture.mock_upload_key_2, "def" - ) - assert uploads.get_upload_ids( - marc_file_upload_session_fixture.mock_upload_key_1 - ) == {marc_file_upload_session_fixture.mock_upload_key_1: "abc"} - assert uploads.get_upload_ids( - [ - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.mock_upload_key_2, - ] - ) == { - marc_file_upload_session_fixture.mock_upload_key_1: "abc", - marc_file_upload_session_fixture.mock_upload_key_2: "def", - } - - def test_add_part_and_clear_buffer( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # If we try to add parts without acquiring the lock, we should get an error - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.part_1, - ) - assert "Must hold lock" in str(exc_info.value) - - # Acquire the lock - uploads.acquire() - - # We are unable to add parts to a library whose buffers haven't been initialized - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.part_1, - ) - assert "Failed to add part and clear buffer" in str(exc_info.value) - - marc_file_upload_session_fixture.load_test_data() - - # We are able to add parts to a library that exists - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.part_1, - ) - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_2, - marc_file_upload_session_fixture.part_1, - ) - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.part_2, - ) - - all_uploads = uploads.get() - # The parts are added in order and the buffers are cleared - assert all_uploads[ - 
marc_file_upload_session_fixture.mock_upload_key_1 - ].parts == [ - marc_file_upload_session_fixture.part_1, - marc_file_upload_session_fixture.part_2, - ] - assert all_uploads[ - marc_file_upload_session_fixture.mock_upload_key_2 - ].parts == [marc_file_upload_session_fixture.part_1] - assert ( - all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].buffer == "" - ) - assert ( - all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].buffer == "" - ) - - def test_get_part_num_and_buffer( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # If the key has not been initialized, we get an exception - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.get_part_num_and_buffer( - marc_file_upload_session_fixture.mock_upload_key_1 - ) - assert "Failed to get part number and buffer data" in str(exc_info.value) - - marc_file_upload_session_fixture.load_test_data() - - # If the buffer has been set, but no parts have been added. The first part number - # should be 1. The buffer should be the same as the original data. - assert uploads.get_part_num_and_buffer( - marc_file_upload_session_fixture.mock_upload_key_1 - ) == ( - 1, - marc_file_upload_session_fixture.test_data[ - marc_file_upload_session_fixture.mock_upload_key_1 - ], - ) - - with uploads.lock() as locked: - assert locked - # Add part 1 - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.part_1, - ) - # Add part 2 - uploads.add_part_and_clear_buffer( - marc_file_upload_session_fixture.mock_upload_key_1, - marc_file_upload_session_fixture.part_2, - ) - uploads.append_buffers( - { - marc_file_upload_session_fixture.mock_upload_key_1: "1234567", - } - ) - - # The next part number should be 3, and the buffer should be the new data - assert uploads.get_part_num_and_buffer( - marc_file_upload_session_fixture.mock_upload_key_1 - ) == (3, "1234567") - - def test_state( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # If the session doesn't exist, the state should be None - assert uploads.state() is None - - # Once the state is created, by locking for example, the state should be SessionState.INITIAL - with uploads.lock(): - assert uploads.state() == MarcFileUploadState.INITIAL - - def test_set_state( - self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture - ): - uploads = marc_file_upload_session_fixture.uploads - - # If we don't hold the lock, we can't set the state - with pytest.raises(MarcFileUploadSessionError) as exc_info: - uploads.set_state(MarcFileUploadState.UPLOADING) - assert "Must hold lock" in str(exc_info.value) - - # Once the state is created, by locking for example, we can set the state - with uploads.lock(): - uploads.set_state(MarcFileUploadState.UPLOADING) - assert uploads.state() == MarcFileUploadState.UPLOADING diff --git a/tests/manager/service/redis/test_escape.py b/tests/manager/service/redis/test_escape.py deleted file mode 100644 index eae33fc99..000000000 --- a/tests/manager/service/redis/test_escape.py +++ /dev/null @@ -1,72 +0,0 @@ -import json -import re -import string - -import pytest - -from palace.manager.core.exceptions import PalaceValueError -from palace.manager.service.redis.escape import JsonPathEscapeMixin - - -class TestPathEscapeMixin: - @pytest.mark.parametrize( - "path", - [ - "", - "test", - string.printable, - 
"test/test1/?$xyz.abc", - "`", - "```", - "/~`\\", - "`\\~/``/", - "a", - "/", - "~", - " ", - '"', - "💣ü", - ], - ) - def test_escape_path(self, path: str) -> None: - # Test a round trip - escaper = JsonPathEscapeMixin() - escaped = escaper._escape_path(path) - unescaped = escaper._unescape_path(escaped) - assert unescaped == path - - # Test a round trip with ElastiCache escaping. The json.loads is done implicitly by ElastiCache, - # when using these strings in a JsonPath query. We add a json.loads here to simulate that. - escaped = escaper._escape_path(path, elasticache=True) - unescaped = escaper._unescape_path(json.loads(f'"{escaped}"')) - assert unescaped == path - - # Test that we can handle escaping the escaped path multiple times - escaped = path - for _ in range(10): - escaped = escaper._escape_path(escaped) - - unescaped = escaped - for _ in range(10): - unescaped = escaper._unescape_path(unescaped) - - assert unescaped == path - - def test_unescape(self) -> None: - escaper = JsonPathEscapeMixin() - assert escaper._unescape_path("") == "" - - with pytest.raises( - PalaceValueError, match=re.escape("Invalid escape sequence '`?'") - ): - escaper._unescape_path("test `?") - - with pytest.raises( - PalaceValueError, match=re.escape("Invalid escape sequence '` '") - ): - escaper._unescape_path("``` test") - - with pytest.raises( - PalaceValueError, match=re.escape("Unterminated escape sequence") - ): - escaper._unescape_path("`") diff --git a/tests/manager/sqlalchemy/model/test_identifier.py b/tests/manager/sqlalchemy/model/test_identifier.py index c0eb9ed13..2ed805a84 100644 --- a/tests/manager/sqlalchemy/model/test_identifier.py +++ b/tests/manager/sqlalchemy/model/test_identifier.py @@ -1,18 +1,23 @@ import datetime -from unittest.mock import PropertyMock, create_autospec +from unittest.mock import MagicMock, PropertyMock, create_autospec import pytest +from palace.manager.core.equivalents_coverage import ( + EquivalentIdentifiersCoverageProvider, +) from palace.manager.sqlalchemy.constants import MediaTypes from palace.manager.sqlalchemy.model.datasource import DataSource from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.model.identifier import ( + Equivalency, Identifier, ProQuestIdentifierParser, RecursiveEquivalencyCache, ) from palace.manager.sqlalchemy.model.resource import Hyperlink from palace.manager.sqlalchemy.presentation import PresentationCalculationPolicy +from palace.manager.sqlalchemy.util import create from tests.fixtures.database import DatabaseTransactionFixture from tests.manager.sqlalchemy.model.test_coverage import ( ExampleEquivalencyCoverageRecordFixture, @@ -765,6 +770,96 @@ def test_identifier_delete_cascade_parent( all_recursives = session.query(RecursiveEquivalencyCache).all() assert len(all_recursives) == 3 + def test_equivalent_identifiers(self, db: DatabaseTransactionFixture) -> None: + # This is already an isbn so no lookup necessary + isbn_identifier = db.identifier(identifier_type=Identifier.ISBN) + + # Overdrive identifier with two ISBN associated + overdrive_identifier = db.identifier(identifier_type=Identifier.OVERDRIVE_ID) + od_isbn_1 = db.identifier(identifier_type=Identifier.ISBN) + od_isbn_2 = db.identifier(identifier_type=Identifier.ISBN) + create( + db.session, + Equivalency, + input_id=overdrive_identifier.id, + output_id=od_isbn_1.id, + strength=5, + ) + create( + db.session, + Equivalency, + input_id=overdrive_identifier.id, + output_id=od_isbn_2.id, + strength=1, + ) + + # Gutenberg ID with one 
associated ISBN
+ gutenberg_identifier = db.identifier(identifier_type=Identifier.GUTENBERG_ID)
+ gutenberg_isbn = db.identifier(identifier_type=Identifier.ISBN)
+ create(
+ db.session,
+ Equivalency,
+ input_id=gutenberg_identifier.id,
+ output_id=gutenberg_isbn.id,
+ strength=5,
+ )
+
+ # ProQuest ID with no associated ISBN but an associated GUTENBERG_ID
+ proquest_identifier = db.identifier(identifier_type=Identifier.PROQUEST_ID)
+ proquest_gutenberg = db.identifier(identifier_type=Identifier.GUTENBERG_ID)
+ create(
+ db.session,
+ Equivalency,
+ input_id=proquest_identifier.id,
+ output_id=proquest_gutenberg.id,
+ strength=2,
+ )
+
+ # We're using the RecursiveEquivalencyCache, so we must refresh it.
+ EquivalentIdentifiersCoverageProvider(db.session).run()
+
+ # Calling with only identifiers of the specified type doesn't do a query,
+ # it just returns the identifiers
+ assert RecursiveEquivalencyCache.equivalent_identifiers(
+ MagicMock(side_effect=Exception("Should not be called")),
+ {
+ isbn_identifier,
+ },
+ type=Identifier.ISBN,
+ ) == {isbn_identifier: isbn_identifier}
+
+ equivalent_isbns = RecursiveEquivalencyCache.equivalent_identifiers(
+ db.session,
+ {
+ isbn_identifier,
+ overdrive_identifier,
+ gutenberg_identifier,
+ proquest_identifier,
+ },
+ type=Identifier.ISBN,
+ )
+ assert equivalent_isbns == {
+ isbn_identifier: isbn_identifier,
+ overdrive_identifier: od_isbn_1,
+ gutenberg_identifier: gutenberg_isbn,
+ }
+
+ equivalents = RecursiveEquivalencyCache.equivalent_identifiers(
+ db.session,
+ {
+ isbn_identifier,
+ overdrive_identifier,
+ gutenberg_identifier,
+ proquest_identifier,
+ },
+ )
+ assert equivalents == {
+ isbn_identifier: isbn_identifier,
+ overdrive_identifier: od_isbn_1,
+ gutenberg_identifier: gutenberg_isbn,
+ proquest_identifier: proquest_gutenberg,
+ }
+
+
 class TestProQuestIdentifierParser:
 @pytest.mark.parametrize(