Skip to content

Commit

Permalink
Create separate tasks for reaping holds by collection.
Browse files Browse the repository at this point in the history
Also use with_for_update to lock expired holds before deletion
Also ensure that transaction is closed before collecting analytics events.
  • Loading branch information
dbernstein committed Dec 3, 2024
1 parent 1efe444 commit 8db8da1
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 76 deletions.
24 changes: 16 additions & 8 deletions src/palace/manager/api/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class LoanlikeReaperMonitor(ReaperMonitor):
OPDSForDistributorsAPI.label(),
]

def __init__(self, *args, **kwargs):
super().__init__(args, kwargs)
self._events_to_be_logged = []

@property
def where_clause(self):
"""We never want to automatically reap loans or holds for situations
Expand Down Expand Up @@ -45,20 +49,24 @@ def where_clause(self):
)
return ~self.MODEL_CLASS.id.in_(source_of_truth_subquery)

def post_delete(self, row: Loan | Hold) -> None:
def delete(self, row) -> None:
ce = CirculationEvent
event_type = (
CirculationEvent.CM_LOAN_EXPIRED
if isinstance(row, Loan)
else CirculationEvent.CM_HOLD_EXPIRED
)

self.services.analytics.collect_event(
event_type = ce.CM_LOAN_EXPIRED if isinstance(row, Loan) else ce.CM_HOLD_EXPIRED
event = dict(
library=row.library,
license_pool=row.license_pool,
event_type=event_type,
patron=row.patron,
)
super().delete(row)
self.events_to_be_logged.append(event)

def after_commit(self) -> None:
super().after_commit()
copy_of_list = list(self._events_to_be_logged)
for event in copy_of_list:
self.services.analytics.collect_event(**event)
self._events_to_be_logged.remove(event)


class LoanReaper(LoanlikeReaperMonitor):
Expand Down
101 changes: 72 additions & 29 deletions src/palace/manager/celery/tasks/opds_odl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from typing import Any

from celery import shared_task
from sqlalchemy import delete, select
Expand All @@ -7,7 +8,6 @@

from palace.manager.api.odl.api import OPDS2WithODLApi
from palace.manager.celery.task import Task
from palace.manager.service.analytics.analytics import Analytics
from palace.manager.service.celery.celery import QueueNames
from palace.manager.service.redis.models.lock import RedisLock
from palace.manager.service.redis.redis import Redis
Expand All @@ -19,13 +19,31 @@


def remove_expired_holds_for_collection(
db: Session, collection_id: int, analytics: Analytics
) -> int:
db: Session,
collection_id: int,
) -> tuple[int, dict[str, Any]]:
"""
Remove expired holds from the database for this collection.
"""

# generate expiration events for expired holds before deleting them
# lock rows
lock_query = (
select(Hold.id)
.where(
Hold.position == 0,
Hold.end < utc_now(),
Hold.license_pool_id == LicensePool.id,
LicensePool.collection_id == collection_id,
)
.with_for_update()
)

db.execute(lock_query).all()

# a separate query is required to get around the
# "FOR UPDATE cannot be applied to the nullable side of an outer join" issue when trying to use with_for_update
# on the Hold object.
select_query = select(Hold).where(
Hold.position == 0,
Hold.end < utc_now(),
Expand All @@ -34,13 +52,17 @@ def remove_expired_holds_for_collection(
)

expired_holds = db.scalars(select_query).all()
expired_hold_events: [dict[str, Any]] = []
for hold in expired_holds:
analytics.collect_event(
library=hold.library,
license_pool=hold.license_pool,
event_type=CirculationEvent.CM_HOLD_EXPIRED,
patron=hold.patron,
expired_hold_events.append(
dict(
library=hold.library,
license_pool=hold.license_pool,
event_type=CirculationEvent.CM_HOLD_EXPIRED,
patron=hold.patron,
)
)

# delete the holds
query = (
delete(Hold)
Expand All @@ -53,11 +75,12 @@ def remove_expired_holds_for_collection(
.execution_options(synchronize_session="fetch")
)
result = db.execute(query)

# We need the type ignores here because result doesn't always have
# a rowcount, but the sqlalchemy docs swear it will in the case of
# a delete statement.
# https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#getting-affected-row-count-from-update-delete
return result.rowcount # type: ignore[attr-defined,no-any-return]
return result.rowcount, expired_hold_events # type: ignore[attr-defined,no-any-return]


def licensepool_ids_with_holds(
Expand Down Expand Up @@ -96,8 +119,7 @@ def lock_licenses(license_pool: LicensePool) -> None:
def recalculate_holds_for_licensepool(
license_pool: LicensePool,
reservation_period: datetime.timedelta,
analytics: Analytics,
) -> int:
) -> tuple[int, dict[str, Any]]:
# We take out row level locks on all the licenses and holds for this license pool, so that
# everything is in a consistent state while we update the hold queue. This means we should be
# quickly committing the transaction, to avoid contention or deadlocks.
Expand All @@ -111,6 +133,8 @@ def recalculate_holds_for_licensepool(
waiting = holds[reserved:]
updated = 0

events: [dict[str, Any]] = []

# These holds have a copy reserved for them.
for hold in ready:
# If this hold isn't already in position 0, the hold just became available.
Expand All @@ -119,11 +143,13 @@ def recalculate_holds_for_licensepool(
hold.position = 0
hold.end = utc_now() + reservation_period
updated += 1
analytics.collect_event(
library=hold.library,
license_pool=hold.license_pool,
event_type=CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT,
patron=hold.patron,
events.append(
dict(
library=hold.library,
license_pool=hold.license_pool,
event_type=CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT,
patron=hold.patron,
)
)

# Update the position for the remaining holds.
Expand All @@ -134,31 +160,45 @@ def recalculate_holds_for_licensepool(
hold.end = None
updated += 1

return updated
return updated, events


@shared_task(queue=QueueNames.default, bind=True)
def remove_expired_holds_for_collection_task(task: Task, collection_id: int) -> None:
"""
A shared task for removing expired holds from the database for a collection
"""
analytics = task.services.analytics.analytics()
with task.transaction() as session:
collection = Collection.by_id(session, collection_id)
removed, events = remove_expired_holds_for_collection(
session,
collection_id,
)
task.log.info(
f"Removed {removed} expired holds for collection {collection.name} ({collection_id})."
)

# publish events only after successful commit
for event in events:
analytics.collect_event(**event)


@shared_task(queue=QueueNames.default, bind=True)
def remove_expired_holds(task: Task) -> None:
"""
Remove expired holds from the database.
Issue remove expired hold tasks for eligible collections
"""
registry = task.services.integration_registry.license_providers()
protocols = registry.get_protocols(OPDS2WithODLApi, default=False)
analytics = task.services.analytics.analytics()
with task.session() as session:
collections = [
(collection.id, collection.name)
for collection in Collection.by_protocol(session, protocols)
if collection.id is not None
]
for collection_id, collection_name in collections:
with task.transaction() as session:
removed = remove_expired_holds_for_collection(
session, collection_id, analytics
)
task.log.info(
f"Removed {removed} expired holds for collection {collection_name} ({collection_id})."
)
remove_expired_holds_for_collection.delay(collection_id)


@shared_task(queue=QueueNames.default, bind=True)
Expand Down Expand Up @@ -191,6 +231,7 @@ def recalculate_hold_queue_collection(
Recalculate the hold queue for a collection.
"""
lock = _redis_lock_recalculate_holds(task.services.redis.client(), collection_id)
analytics = task.services.analytics.analytics()
with lock.lock() as locked:
if not locked:
task.log.info(
Expand Down Expand Up @@ -233,11 +274,9 @@ def recalculate_hold_queue_collection(
)
continue

analytics = task.services.analytics.analytics()
updated = recalculate_holds_for_licensepool(
updated, events = recalculate_holds_for_licensepool(
license_pool,
reservation_period,
analytics,
)
edition = license_pool.presentation_edition
title = edition.title if edition else None
Expand All @@ -247,6 +286,10 @@ def recalculate_hold_queue_collection(
f"{updated} holds out of date."
)

# fire events after successful database update
for event in events:
analytics.collect_event(**event)

if len(license_pool_ids) == batch_size:
# We are done this batch, but there is probably more work to do, we queue up the next batch.
raise task.replace(
Expand Down
14 changes: 5 additions & 9 deletions src/palace/manager/core/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,22 +880,21 @@ def run_once(self, *args, **kwargs):
count = qu.count()
self.log.info("Deleting %d row(s)", count)
while count > 0:
post_delete_ops = []
for i in qu.limit(self.BATCH_SIZE):
self.log.info("Deleting %r", i)
self.delete(i)
self.post_delete(i)
rows_deleted += 1

self._db.commit()

for op in post_delete_ops:
op()
self.after_commit()

count = qu.count()
return TimestampData(achievements="Items deleted: %d" % rows_deleted)

def delete(self, row):
def after_commit(self) -> None:
return None

def delete(self, row) -> None:
"""Delete a row from the database.
CAUTION: If you override this method such that it doesn't
Expand All @@ -904,9 +903,6 @@ def delete(self, row):
"""
self._db.delete(row)

def post_delete(self, row) -> None:
return None

def query(self):
return self._db.query(self.MODEL_CLASS).filter(self.where_clause)

Expand Down
Loading

0 comments on commit 8db8da1

Please sign in to comment.