From 8db8da13907464654e6373dd03619cd1776a4299 Mon Sep 17 00:00:00 2001 From: Daniel Bernstein Date: Tue, 3 Dec 2024 12:00:55 -0800 Subject: [PATCH] Create separate tasks for reaping holds by collection. Also use with_for_update to lock expired holds before deletion Also ensure that transaction is closed before collecting analytics events. --- src/palace/manager/api/monitor.py | 24 +++-- src/palace/manager/celery/tasks/opds_odl.py | 101 ++++++++++++++------ src/palace/manager/core/monitor.py | 14 +-- tests/manager/celery/tasks/test_opds_odl.py | 68 +++++++------ 4 files changed, 131 insertions(+), 76 deletions(-) diff --git a/src/palace/manager/api/monitor.py b/src/palace/manager/api/monitor.py index b22014a36..ed112a6ee 100644 --- a/src/palace/manager/api/monitor.py +++ b/src/palace/manager/api/monitor.py @@ -15,6 +15,10 @@ class LoanlikeReaperMonitor(ReaperMonitor): OPDSForDistributorsAPI.label(), ] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._events_to_be_logged = [] + @property def where_clause(self): """We never want to automatically reap loans or holds for situations @@ -45,20 +49,24 @@ def where_clause(self): ) return ~self.MODEL_CLASS.id.in_(source_of_truth_subquery) - def post_delete(self, row: Loan | Hold) -> None: + def delete(self, row) -> None: ce = CirculationEvent - event_type = ( - CirculationEvent.CM_LOAN_EXPIRED - if isinstance(row, Loan) - else CirculationEvent.CM_HOLD_EXPIRED - ) - - self.services.analytics.collect_event( + event_type = ce.CM_LOAN_EXPIRED if isinstance(row, Loan) else ce.CM_HOLD_EXPIRED + event = dict( library=row.library, license_pool=row.license_pool, event_type=event_type, patron=row.patron, ) + super().delete(row) + self._events_to_be_logged.append(event) + + def after_commit(self) -> None: + super().after_commit() + copy_of_list = list(self._events_to_be_logged) + for event in copy_of_list: + self.services.analytics.collect_event(**event) + self._events_to_be_logged.remove(event) class
LoanReaper(LoanlikeReaperMonitor): diff --git a/src/palace/manager/celery/tasks/opds_odl.py b/src/palace/manager/celery/tasks/opds_odl.py index 325631901..cce437386 100644 --- a/src/palace/manager/celery/tasks/opds_odl.py +++ b/src/palace/manager/celery/tasks/opds_odl.py @@ -1,4 +1,5 @@ import datetime +from typing import Any from celery import shared_task from sqlalchemy import delete, select @@ -7,7 +8,6 @@ from palace.manager.api.odl.api import OPDS2WithODLApi from palace.manager.celery.task import Task -from palace.manager.service.analytics.analytics import Analytics from palace.manager.service.celery.celery import QueueNames from palace.manager.service.redis.models.lock import RedisLock from palace.manager.service.redis.redis import Redis @@ -19,13 +19,31 @@ def remove_expired_holds_for_collection( - db: Session, collection_id: int, analytics: Analytics -) -> int: + db: Session, + collection_id: int, +) -> tuple[int, list[dict[str, Any]]]: """ Remove expired holds from the database for this collection. """ # generate expiration events for expired holds before deleting them + # lock rows + lock_query = ( + select(Hold.id) + .where( + Hold.position == 0, + Hold.end < utc_now(), + Hold.license_pool_id == LicensePool.id, + LicensePool.collection_id == collection_id, + ) + .with_for_update() + ) + + db.execute(lock_query).all() + + # a separate query is required to get around the + # "FOR UPDATE cannot be applied to the nullable side of an outer join" issue when trying to use with_for_update + # on the Hold object.
select_query = select(Hold).where( Hold.position == 0, Hold.end < utc_now(), @@ -34,13 +52,17 @@ def remove_expired_holds_for_collection( ) expired_holds = db.scalars(select_query).all() + expired_hold_events: list[dict[str, Any]] = [] for hold in expired_holds: - analytics.collect_event( - library=hold.library, - license_pool=hold.license_pool, - event_type=CirculationEvent.CM_HOLD_EXPIRED, - patron=hold.patron, + expired_hold_events.append( + dict( + library=hold.library, + license_pool=hold.license_pool, + event_type=CirculationEvent.CM_HOLD_EXPIRED, + patron=hold.patron, + ) ) + # delete the holds query = ( delete(Hold) @@ -53,11 +75,12 @@ def remove_expired_holds_for_collection( .execution_options(synchronize_session="fetch") ) result = db.execute(query) + # We need the type ignores here because result doesn't always have # a rowcount, but the sqlalchemy docs swear it will in the case of # a delete statement. # https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#getting-affected-row-count-from-update-delete - return result.rowcount # type: ignore[attr-defined,no-any-return] + return result.rowcount, expired_hold_events # type: ignore[attr-defined,no-any-return] def licensepool_ids_with_holds( @@ -96,8 +119,7 @@ def lock_licenses(license_pool: LicensePool) -> None: def recalculate_holds_for_licensepool( license_pool: LicensePool, reservation_period: datetime.timedelta, - analytics: Analytics, -) -> int: +) -> tuple[int, list[dict[str, Any]]]: # We take out row level locks on all the licenses and holds for this license pool, so that # everything is in a consistent state while we update the hold queue. This means we should be # quickly committing the transaction, to avoid contention or deadlocks. @@ -111,6 +133,8 @@ def recalculate_holds_for_licensepool( waiting = holds[reserved:] updated = 0 + events: list[dict[str, Any]] = [] + # These holds have a copy reserved for them. for hold in ready: # If this hold isn't already in position 0, the hold just became available.
@@ -119,11 +143,13 @@ def recalculate_holds_for_licensepool( hold.position = 0 hold.end = utc_now() + reservation_period updated += 1 - analytics.collect_event( - library=hold.library, - license_pool=hold.license_pool, - event_type=CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT, - patron=hold.patron, + events.append( + dict( + library=hold.library, + license_pool=hold.license_pool, + event_type=CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT, + patron=hold.patron, + ) ) # Update the position for the remaining holds. @@ -134,17 +160,37 @@ def recalculate_holds_for_licensepool( hold.end = None updated += 1 - return updated + return updated, events + + +@shared_task(queue=QueueNames.default, bind=True) +def remove_expired_holds_for_collection_task(task: Task, collection_id: int) -> None: + """ + A shared task for removing expired holds from the database for a collection + """ + analytics = task.services.analytics.analytics() + with task.transaction() as session: + collection = Collection.by_id(session, collection_id) + removed, events = remove_expired_holds_for_collection( + session, + collection_id, + ) + task.log.info( + f"Removed {removed} expired holds for collection {collection.name} ({collection_id})." + ) + + # publish events only after successful commit + for event in events: + analytics.collect_event(**event) @shared_task(queue=QueueNames.default, bind=True) def remove_expired_holds(task: Task) -> None: """ - Remove expired holds from the database. 
+ Issue remove expired hold tasks for eligible collections """ registry = task.services.integration_registry.license_providers() protocols = registry.get_protocols(OPDS2WithODLApi, default=False) - analytics = task.services.analytics.analytics() with task.session() as session: collections = [ (collection.id, collection.name) @@ -152,13 +198,7 @@ def remove_expired_holds(task: Task) -> None: if collection.id is not None ] for collection_id, collection_name in collections: - with task.transaction() as session: - removed = remove_expired_holds_for_collection( - session, collection_id, analytics - ) - task.log.info( - f"Removed {removed} expired holds for collection {collection_name} ({collection_id})." - ) + remove_expired_holds_for_collection_task.delay(collection_id) @shared_task(queue=QueueNames.default, bind=True) @@ -191,6 +231,7 @@ def recalculate_hold_queue_collection( Recalculate the hold queue for a collection. """ lock = _redis_lock_recalculate_holds(task.services.redis.client(), collection_id) + analytics = task.services.analytics.analytics() with lock.lock() as locked: if not locked: task.log.info( @@ -233,11 +274,9 @@ def recalculate_hold_queue_collection( ) continue - analytics = task.services.analytics.analytics() - updated = recalculate_holds_for_licensepool( + updated, events = recalculate_holds_for_licensepool( license_pool, reservation_period, - analytics, ) edition = license_pool.presentation_edition title = edition.title if edition else None @@ -247,6 +286,10 @@ def recalculate_hold_queue_collection( f"{updated} holds out of date." ) + # fire events after successful database update + for event in events: + analytics.collect_event(**event) + if len(license_pool_ids) == batch_size: # We are done this batch, but there is probably more work to do, we queue up the next batch.
raise task.replace( diff --git a/src/palace/manager/core/monitor.py b/src/palace/manager/core/monitor.py index 914f61276..444f04579 100644 --- a/src/palace/manager/core/monitor.py +++ b/src/palace/manager/core/monitor.py @@ -880,22 +880,21 @@ def run_once(self, *args, **kwargs): count = qu.count() self.log.info("Deleting %d row(s)", count) while count > 0: - post_delete_ops = [] for i in qu.limit(self.BATCH_SIZE): self.log.info("Deleting %r", i) self.delete(i) - self.post_delete(i) rows_deleted += 1 - self._db.commit() - for op in post_delete_ops: - op() + self.after_commit() count = qu.count() return TimestampData(achievements="Items deleted: %d" % rows_deleted) - def delete(self, row): + def after_commit(self) -> None: + return None + + def delete(self, row) -> None: """Delete a row from the database. CAUTION: If you override this method such that it doesn't @@ -904,9 +903,6 @@ def delete(self, row): """ self._db.delete(row) - def post_delete(self, row) -> None: - return None - def query(self): return self._db.query(self.MODEL_CLASS).filter(self.where_clause) diff --git a/tests/manager/celery/tasks/test_opds_odl.py b/tests/manager/celery/tasks/test_opds_odl.py index 12d996832..6ffe09a29 100644 --- a/tests/manager/celery/tasks/test_opds_odl.py +++ b/tests/manager/celery/tasks/test_opds_odl.py @@ -17,6 +17,7 @@ recalculate_holds_for_licensepool, remove_expired_holds, remove_expired_holds_for_collection, + remove_expired_holds_for_collection_task, ) from palace.manager.service.logging.configuration import LogLevel from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent @@ -152,12 +153,11 @@ def test_remove_expired_holds_for_collection( select(func.count()).select_from(LicensePool) ).one() - analytics = opds_task_fixture.services.analytics_fixture.analytics_mock - # Remove the expired holds assert collection.id is not None - removed = remove_expired_holds_for_collection( - db.session, collection.id, analytics=analytics + removed, events =
remove_expired_holds_for_collection( + db.session, + collection.id, ) # Assert that the correct holds were removed @@ -178,10 +178,9 @@ def test_remove_expired_holds_for_collection( assert pools_before == pools_after # verify that the correct analytics calls were made - call_args_list = analytics.collect_event.call_args_list - assert len(call_args_list) == 10 - for call_args in call_args_list: - assert call_args.kwargs["event_type"] == CirculationEvent.CM_HOLD_EXPIRED + assert len(events) == 10 + for event in events: + assert event["event_type"] == CirculationEvent.CM_HOLD_EXPIRED def test_licensepools_with_holds( @@ -231,7 +230,7 @@ def test_recalculate_holds_for_licensepool( analytics = opds_task_fixture.services.analytics_fixture.analytics_mock # Recalculate the hold queue - recalculate_holds_for_licensepool(pool, timedelta(days=5), analytics=analytics) + recalculate_holds_for_licensepool(pool, timedelta(days=5)) current_holds = pool.get_active_holds() assert len(current_holds) == 20 @@ -242,7 +241,7 @@ def test_recalculate_holds_for_licensepool( license1.checkouts_available = 1 license2.checkouts_available = 2 reservation_time = timedelta(days=5) - recalculate_holds_for_licensepool(pool, reservation_time, analytics) + _, events = recalculate_holds_for_licensepool(pool, reservation_time) assert pool.licenses_reserved == 3 assert pool.licenses_available == 0 @@ -271,42 +270,51 @@ def test_recalculate_holds_for_licensepool( ) assert hold.start and expected_start and hold.start >= expected_start - # verify that the correct analytics calls were made - call_args_list = analytics.collect_event.call_args_list - assert len(call_args_list) == 3 - for call_args in call_args_list: - assert ( - call_args.kwargs["event_type"] - == CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT - ) + # verify that the correct analytics events were returned + assert len(events) == 3 + for event in events: + assert event["event_type"] == CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT -def
test_remove_expired_holds( +def test_remove_expired_holds_for_collection_task( celery_fixture: CeleryFixture, db: DatabaseTransactionFixture, opds_task_fixture: OpdsTaskFixture, ): collection1 = db.collection(protocol=OPDS2WithODLApi) - collection2 = db.collection(protocol=OPDS2WithODLApi) - decoy_collection = db.collection(protocol=OverdriveAPI) expired_holds1, non_expired_holds1 = opds_task_fixture.holds(collection1) - expired_holds2, non_expired_holds2 = opds_task_fixture.holds(collection2) - decoy_expired_holds, decoy_non_expired_holds = opds_task_fixture.holds( - decoy_collection - ) # Remove the expired holds - remove_expired_holds.delay().wait() + remove_expired_holds_for_collection_task.delay(collection1.id).wait() + + assert len( + opds_task_fixture.services.analytics_fixture.analytics_mock.method_calls + ) == len(expired_holds1) current_holds = {h.id for h in db.session.scalars(select(Hold))} assert expired_holds1.isdisjoint(current_holds) - assert expired_holds2.isdisjoint(current_holds) - assert decoy_non_expired_holds.issubset(current_holds) - assert decoy_expired_holds.issubset(current_holds) assert non_expired_holds1.issubset(current_holds) - assert non_expired_holds2.issubset(current_holds) + + +def test_remove_expired_holds( + celery_fixture: CeleryFixture, + redis_fixture: RedisFixture, + db: DatabaseTransactionFixture, + opds_task_fixture: OpdsTaskFixture, +): + collection1 = db.collection(protocol=OPDS2WithODLApi) + collection2 = db.collection(protocol=OPDS2WithODLApi) + decoy_collection = db.collection(protocol=OverdriveAPI) + + with patch.object(opds_odl, "remove_expired_holds_for_collection_task") as mock_remove: + remove_expired_holds.delay().wait() + + assert mock_remove.delay.call_count == 2 + mock_remove.delay.assert_has_calls( + [call(collection1.id), call(collection2.id)], any_order=True + ) def test_recalculate_hold_queue(