From 4e2f2694a43e78edce6ea4bfbe817dbb18156c0f Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Thu, 4 Apr 2024 20:07:34 -0300 Subject: [PATCH] Celery queue prototype --- api/admin/controller/collection_settings.py | 2 + api/admin/controller/report.py | 55 ++-- bin/delete_old_deferred_tasks | 12 - bin/generate_inventory_reports | 12 - core/celery/__init__.py | 0 core/celery/job.py | 28 ++ core/celery/task.py | 26 ++ core/celery/tasks/__init__.py | 0 core/celery/tasks/collection_delete.py | 44 +++ core/celery/tasks/inventory_reports.py | 189 ++++++++++++ core/celery/worker.py | 35 +++ core/model/__init__.py | 9 +- core/scripts.py | 258 +---------------- core/service/__init__.py | 0 core/service/analytics/__init__.py | 0 core/service/celery/__init__.py | 0 core/service/celery/celery.py | 31 ++ core/service/celery/configuration.py | 28 ++ core/service/celery/container.py | 11 + core/service/container.py | 8 + core/service/email/__init__.py | 0 core/service/email/email.py | 7 +- core/service/logging/__init__.py | 0 core/service/logging/log.py | 9 + core/service/search/__init__.py | 0 core/service/storage/__init__.py | 0 core/tasks/__init__.py | 0 core/tasks/celery.py | 0 docker-compose.yml | 15 + docker/services/cron/cron.d/circulation | 4 - poetry.lock | 268 +++++++++++++++++- pyproject.toml | 25 +- .../api/admin/controller/test_collections.py | 13 +- tests/api/admin/controller/test_report.py | 37 ++- tests/core/celery/__init__.py | 0 tests/core/celery/tasks/__init__.py | 0 .../celery/tasks/test_collection_delete.py | 61 ++++ .../celery/tasks/test_inventory_reports.py | 142 ++++++++++ tests/core/conftest.py | 1 + tests/core/service/logging/test_log.py | 21 ++ tests/core/test_app_server.py | 17 +- tests/core/test_opds2_import.py | 2 + tests/core/test_scripts.py | 203 +------------ tests/fixtures/celery.py | 97 +++++++ tests/fixtures/database.py | 21 +- tests/fixtures/services.py | 21 ++ 46 files changed, 1154 insertions(+), 558 deletions(-) delete mode 100755 bin/delete_old_deferred_tasks delete mode 100755 bin/generate_inventory_reports create mode 100644 core/celery/__init__.py create mode 100644 core/celery/job.py create mode 100644 core/celery/task.py create mode 100644 core/celery/tasks/__init__.py create mode 100644 core/celery/tasks/collection_delete.py create mode 100644 core/celery/tasks/inventory_reports.py create mode 100644 core/celery/worker.py create mode 100644 core/service/__init__.py create mode 100644 core/service/analytics/__init__.py create mode 100644 core/service/celery/__init__.py create mode 100644 core/service/celery/celery.py create mode 100644 core/service/celery/configuration.py create mode 100644 core/service/celery/container.py create mode 100644 core/service/email/__init__.py create mode 100644 core/service/logging/__init__.py create mode 100644 core/service/search/__init__.py create mode 100644 core/service/storage/__init__.py create mode 100644 core/tasks/__init__.py create mode 100644 core/tasks/celery.py create mode 100644 tests/core/celery/__init__.py create mode 100644 tests/core/celery/tasks/__init__.py create mode 100644 tests/core/celery/tasks/test_collection_delete.py create mode 100644 tests/core/celery/tasks/test_inventory_reports.py create mode 100644 tests/fixtures/celery.py diff --git a/api/admin/controller/collection_settings.py b/api/admin/controller/collection_settings.py index bbf91a8264..7144c365b6 100644 --- a/api/admin/controller/collection_settings.py +++ b/api/admin/controller/collection_settings.py @@ -19,6 +19,7 @@ ) from 
api.circulation import CirculationApiType from api.integration.registry.license_providers import LicenseProvidersRegistry +from core.celery.tasks.collection_delete import collection_delete from core.integration.base import HasChildIntegrationConfiguration from core.integration.registry import IntegrationRegistry from core.model import ( @@ -169,6 +170,7 @@ def process_delete(self, service_id: int) -> Response | ProblemDetail: # Flag the collection to be deleted by script in the background. collection.marked_for_deletion = True + collection_delete.delay(collection.id) return Response("Deleted", 200) def process_collection_self_tests( diff --git a/api/admin/controller/report.py b/api/admin/controller/report.py index 8155216063..aafd9de046 100644 --- a/api/admin/controller/report.py +++ b/api/admin/controller/report.py @@ -1,19 +1,13 @@ import json -from dataclasses import asdict from http import HTTPStatus import flask from flask import Response from sqlalchemy.orm import Session +from core.celery.tasks.inventory_reports import generate_inventory_reports from core.model import Library from core.model.admin import Admin -from core.model.deferredtask import ( - DeferredTaskType, - InventoryReportTaskData, - queue_task, -) -from core.problem_details import INTERNAL_SERVER_ERROR from core.util.log import LoggerMixin from core.util.problem_detail import ProblemDetail @@ -25,29 +19,24 @@ def __init__(self, db: Session): def generate_inventory_report(self) -> Response | ProblemDetail: library: Library = getattr(flask.request, "library") admin: Admin = getattr(flask.request, "admin") - try: - # these values should never be None - assert admin.email - assert admin.id - assert library.id - - data: InventoryReportTaskData = InventoryReportTaskData( - admin_email=admin.email, admin_id=admin.id, library_id=library.id - ) - task, is_new = queue_task( - self._db, task_type=DeferredTaskType.INVENTORY_REPORT, data=asdict(data) - ) - - msg = ( - f"An inventory report request was {'already' if not is_new else ''} received at {task.created}. " - f"When processing is complete, the report will be sent to {admin.email}." - ) - - self.log.info(msg + f" {task}") - http_status = HTTPStatus.ACCEPTED if is_new else HTTPStatus.CONFLICT - return Response(json.dumps(dict(message=msg)), http_status) - except Exception as e: - msg = f"failed to generate inventory report request: {e}" - self.log.error(msg=msg, exc_info=e) - self._db.rollback() - return INTERNAL_SERVER_ERROR.detailed(detail=msg) + + # these values should never be None + assert admin.email + assert admin.id + assert library.id + + task = generate_inventory_reports.delay( + library_id=library.id, admin_email=admin.email + ) + + msg = ( + f"An inventory report request was received. " + f"When processing is complete, the report will be sent to {admin.email}." 
+ ) + + self.log.info(msg + f" {task.id}") + return Response( + json.dumps(dict(message=msg)), + status=HTTPStatus.ACCEPTED, + mimetype="application/json", + ) diff --git a/bin/delete_old_deferred_tasks b/bin/delete_old_deferred_tasks deleted file mode 100755 index e2a4d811e6..0000000000 --- a/bin/delete_old_deferred_tasks +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python -"""Delete completed deferred tasks over 30 days old.""" -import os -import sys - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.scripts import DeleteOldDeferredTasks - -DeleteOldDeferredTasks().run() diff --git a/bin/generate_inventory_reports b/bin/generate_inventory_reports deleted file mode 100755 index 3fff5d7d07..0000000000 --- a/bin/generate_inventory_reports +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python -"""Update the cached sizes of all custom lists.""" -import os -import sys - -bin_dir = os.path.split(__file__)[0] -package_dir = os.path.join(bin_dir, "..") -sys.path.append(os.path.abspath(package_dir)) - -from core.scripts import GenerateInventoryReports - -GenerateInventoryReports().run() diff --git a/core/celery/__init__.py b/core/celery/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/celery/job.py b/core/celery/job.py new file mode 100644 index 0000000000..1ba5f0b1d9 --- /dev/null +++ b/core/celery/job.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Generator +from contextlib import contextmanager + +from sqlalchemy.orm import Session, sessionmaker + +from core.util.log import LoggerMixin + + +class Job(LoggerMixin, ABC): + def __init__(self, session_maker: sessionmaker[Session]): + self._session_maker = session_maker + + @contextmanager + def session(self) -> Generator[Session, None, None]: + with self._session_maker() as session: + yield session + + @contextmanager + def transaction(self) -> Generator[Session, None, None]: + with self._session_maker.begin() as session: + yield session + + @abstractmethod + def run(self) -> None: + ... 
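
The Job class above is the base the rest of this patch builds on: a job never
holds a long-lived database session, it asks the worker's sessionmaker for a
session scoped to the work at hand. A minimal sketch of a subclass, using a
hypothetical ExampleJob that is not part of this patch:

    from sqlalchemy import text
    from sqlalchemy.orm import Session, sessionmaker

    from core.celery.job import Job


    class ExampleJob(Job):
        """Hypothetical job illustrating how Job subclasses are wired up."""

        def __init__(self, session_maker: sessionmaker[Session], library_id: int):
            super().__init__(session_maker)
            self.library_id = library_id

        def run(self) -> None:
            # transaction() commits on a clean exit and rolls back on an
            # exception, because sessionmaker.begin() starts the session
            # inside a transaction block. session() just closes the session
            # on exit, without an automatic commit.
            with self.transaction() as session:
                session.execute(text("select 1"))
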
diff --git a/core/celery/task.py b/core/celery/task.py
new file mode 100644
index 0000000000..023b726f58
--- /dev/null
+++ b/core/celery/task.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import celery
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.pool import NullPool
+
+from core.model import SessionManager
+from core.service.container import Services, container_instance
+from core.util.log import LoggerMixin
+
+
+class Task(celery.Task, LoggerMixin):
+    _session_maker: sessionmaker[Session] | None = None
+
+    @property
+    def session_maker(self) -> sessionmaker[Session]:
+        if self._session_maker is None:
+            engine = SessionManager.engine(poolclass=NullPool)
+            maker = sessionmaker(bind=engine)
+            SessionManager.setup_event_listener(maker)
+            self._session_maker = maker
+        return self._session_maker
+
+    @property
+    def services(self) -> Services:
+        return container_instance()
diff --git a/core/celery/tasks/__init__.py b/core/celery/tasks/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/core/celery/tasks/collection_delete.py b/core/celery/tasks/collection_delete.py
new file mode 100644
index 0000000000..23c88cb23e
--- /dev/null
+++ b/core/celery/tasks/collection_delete.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from celery import shared_task
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.celery.job import Job
+from core.celery.task import Task
+from core.model import Collection
+
+
+class CollectionDeleteJob(Job):
+    def __init__(self, session_maker: sessionmaker[Session], collection_id: int):
+        super().__init__(session_maker)
+        self.collection_id = collection_id
+
+    @staticmethod
+    def collection(session: Session, collection_id: int) -> Collection | None:
+        return (
+            session.execute(select(Collection).where(Collection.id == collection_id))
+            .scalars()
+            .one_or_none()
+        )
+
+    @staticmethod
+    def collection_name(collection: Collection) -> str:
+        return f"{collection.name}/{collection.protocol} ({collection.id})"
+
+    def run(self) -> None:
+        with self.transaction() as session:
+            collection = self.collection(session, self.collection_id)
+            if collection is None:
+                self.log.error(
+                    f"Collection with id {self.collection_id} not found. Unable to delete."
+                )
+                return
+
+            self.log.info(f"Deleting collection {self.collection_name(collection)}")
+            collection.delete()
+
+
+@shared_task(queue="high", bind=True)
+def collection_delete(task: Task, collection_id: int) -> None:
+    CollectionDeleteJob(task.session_maker, collection_id).run()
diff --git a/core/celery/tasks/inventory_reports.py b/core/celery/tasks/inventory_reports.py
new file mode 100644
index 0000000000..4ed1610612
--- /dev/null
+++ b/core/celery/tasks/inventory_reports.py
@@ -0,0 +1,189 @@
+"""Generate inventory reports as Celery tasks."""
+
+from __future__ import annotations
+import csv
+import datetime
+import os
+import tempfile
+
+from celery import shared_task
+from sqlalchemy import not_, select, text
+from sqlalchemy.orm import Session, sessionmaker
+
+from api.integration.registry.license_providers import LicenseProvidersRegistry
+from core.celery.job import Job
+from core.celery.task import Task
+from core.integration.goals import Goals
+from core.model import IntegrationConfiguration, IntegrationLibraryConfiguration
+from core.opds_import import OPDSImporterSettings
+from core.service.email.email import SendEmailCallable
+
+
+class InventoryReportsJob(Job):
+    def __init__(
+        self,
+        session_maker: sessionmaker[Session],
+        send_email: SendEmailCallable,
+        library_id: int,
+        admin_email: str,
+    ):
+        super().__init__(session_maker=session_maker)
+        self.library_id = library_id
+        self.admin_email = admin_email
+        self.send_email = send_email
+
+    @staticmethod
+    def query() -> str:
+        return """
+            select
+             e.title,
+             e.author,
+             i.identifier,
+             e.language,
+             e.publisher,
+             e.medium as format,
+             ic.name collection_name,
+             DATE_PART('day', l.expires::date) - DATE_PART('day',lp.availability_time::date) as license_duration_days,
+             l.expires license_expiration_date,
+             l.checkouts_available initial_loan_count,
+             (l.checkouts_available-l.checkouts_left) consumed_loans,
+             l.checkouts_left remaining_loans,
+             l.terms_concurrency allowed_concurrent_users,
+             coalesce(lib_holds.active_hold_count, 0) library_active_hold_count,
+             coalesce(lib_loans.active_loan_count, 0) library_active_loan_count,
+             CASE WHEN collection_sharing.is_shared_collection THEN lp.patrons_in_hold_queue
+                  ELSE -1
+             END shared_active_hold_count,
+             CASE WHEN collection_sharing.is_shared_collection THEN lp.licenses_reserved
+                  ELSE -1
+             END shared_active_loan_count
+            from datasources d,
+                 collections c,
+                 integration_configurations ic,
+                 integration_library_configurations il,
+                 libraries lib,
+                 editions e,
+                 identifiers i,
+                 (select ic.parent_id,
+                         count(ic.parent_id) > 1 is_shared_collection
+                  from integration_library_configurations ic,
+                       integration_configurations i,
+                       collections c
+                  where c.integration_configuration_id = i.id and
+                        i.id = ic.parent_id group by ic.parent_id) collection_sharing,
+                 licensepools lp left outer join licenses l on lp.id = l.license_pool_id
+                 left outer join (select h.license_pool_id,
+                                         p.library_id,
+                                         count(h.id) active_hold_count
+                                  from holds h,
+                                       patrons p,
+                                       libraries l
+                                  where p.id = h.patron_id and
+                                        p.library_id = l.id and
+                                        l.id = :library_id
+                                  group by p.library_id, h.license_pool_id) lib_holds on lp.id = lib_holds.license_pool_id
+                 left outer join (select ln.license_pool_id,
+                                         p.library_id,
+                                         count(ln.id) active_loan_count
+                                  from loans ln,
+                                       patrons p,
+                                       libraries l
+                                  where p.id = ln.patron_id and
+                                        p.library_id = l.id and
+                                        l.id = :library_id
+                                  group by p.library_id, ln.license_pool_id) lib_loans on lp.id = lib_loans.license_pool_id
+            where lp.identifier_id = i.id and
+                  
e.primary_identifier_id = i.id and
+                  d.id = e.data_source_id and
+                  c.id = lp.collection_id and
+                  c.integration_configuration_id = ic.id and
+                  ic.id = il.parent_id and
+                  ic.id = collection_sharing.parent_id and
+                  il.library_id = lib.id and
+                  d.name = :data_source_name and
+                  lib.id = :library_id
+            order by title, author
+        """
+
+    def generate_report(self, session: Session, data_source_name: str) -> str:
+        """Generate a csv file and return the file path"""
+        with tempfile.NamedTemporaryFile(
+            "w",
+            delete=False,
+        ) as temp:
+            writer = csv.writer(temp, delimiter=",")
+            rows = session.execute(
+                text(self.query()),
+                {"library_id": self.library_id, "data_source_name": data_source_name},
+            )
+            writer.writerow(rows.keys())
+            writer.writerows(rows)
+            return temp.name
+
+    def generate_inventory_reports(self, session: Session) -> None:
+        files = []
+        try:
+            current_time = datetime.datetime.now()
+            date_str = current_time.strftime("%Y-%m-%d_%H:%M:%S")
+            attachments = {}
+
+            integrations = session.scalars(
+                select(IntegrationConfiguration)
+                .join(IntegrationLibraryConfiguration)
+                .where(
+                    IntegrationLibraryConfiguration.library_id == self.library_id,
+                    IntegrationConfiguration.goal == Goals.LICENSE_GOAL,
+                    not_(
+                        IntegrationConfiguration.settings_dict.contains(
+                            {"include_in_inventory_report": False}
+                        )
+                    ),
+                )
+            ).all()
+
+            registry = LicenseProvidersRegistry()
+            data_source_names = []
+            for integration in integrations:
+                settings = registry[integration.protocol].settings_load(integration)
+                if not isinstance(settings, OPDSImporterSettings):
+                    continue
+                data_source_names.append(settings.data_source)
+
+            for data_source_name in data_source_names:
+                formatted_ds_name = data_source_name.lower().replace(" ", "_")
+                file_name = f"palace-inventory-report-{formatted_ds_name}-{date_str}"
+                # generate csv file
+                file_path = self.generate_report(
+                    session,
+                    data_source_name=data_source_name,
+                )
+                # extract contents of files and prepare in a dictionary of email attachments
+                with open(file_path) as csv_file:
+                    attachments[f"{file_name}.csv"] = csv_file.read()
+                    files.append(csv_file)
+
+            self.send_email(
+                subject=f"Inventory Report {current_time}",
+                receivers=[self.admin_email],
+                text="",
+                attachments=attachments,
+            )
+        finally:
+            for file in files:
+                os.remove(file.name)
+
+    def run(self) -> None:
+        self.log.info("Generating inventory report")
+        with self.session() as session:
+            self.generate_inventory_reports(session)
+
+
+@shared_task(queue="high", bind=True)
+def generate_inventory_reports(task: Task, library_id: int, admin_email: str) -> None:
+    task.log.info("Generating inventory reports")
+    InventoryReportsJob(
+        session_maker=task.session_maker,
+        send_email=task.services.email.send_email,
+        library_id=library_id,
+        admin_email=admin_email,
+    ).run()
diff --git a/core/celery/worker.py b/core/celery/worker.py
new file mode 100644
index 0000000000..400e796c6e
--- /dev/null
+++ b/core/celery/worker.py
@@ -0,0 +1,35 @@
+# This file provides the entry point for the Celery worker.
+#
You can use the following command: +# celery -A "core.celery.worker.app" worker + +import importlib +from pathlib import Path +from typing import Any + +from celery.signals import setup_logging + +from core.service.container import container_instance + + +def import_celery_tasks() -> None: + tasks_path = Path(__file__).parent / "tasks" + for task_file in tasks_path.glob("*.py"): + if task_file.stem == "__init__": + continue + module = f"core.celery.tasks.{task_file.stem}" + importlib.import_module(module) + + +@setup_logging.connect +def celery_logger_setup( + loglevel: int, logfile: str, format: str, colorize: bool, **kwargs: Any +) -> None: + # Override the default Celery logger setup to use the logger configuration from the service container, + # this will likely need to be updated so that we respect some of the Celery specific configuration options. + ... + + +services = container_instance() +services.init_resources() +import_celery_tasks() +app = services.celery.app() diff --git a/core/model/__init__.py b/core/model/__init__.py index cdee105037..0b47cd62f7 100644 --- a/core/model/__init__.py +++ b/core/model/__init__.py @@ -16,6 +16,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.pool import Pool from sqlalchemy.sql import compiler, select from sqlalchemy.sql.expression import literal_column, table @@ -336,10 +337,14 @@ class SessionManager: RECURSIVE_EQUIVALENTS_FUNCTION = "recursive_equivalents.sql" @classmethod - def engine(cls, url=None): + def engine(cls, url: str | None = None, poolclass: type[Pool] | None = None): url = url or Configuration.database_url() return create_engine( - url, echo=DEBUG, json_serializer=json_serializer, pool_pre_ping=True + url, + echo=DEBUG, + json_serializer=json_serializer, + pool_pre_ping=True, + poolclass=poolclass, ) @classmethod diff --git a/core/scripts.py b/core/scripts.py index 37a3ba76fa..dc548d607c 100644 --- a/core/scripts.py +++ b/core/scripts.py @@ -1,12 +1,10 @@ import argparse -import csv import datetime import json import logging import os import random import sys -import tempfile import traceback import unicodedata import uuid @@ -14,7 +12,7 @@ from enum import Enum from typing import TextIO -from sqlalchemy import and_, exists, not_, or_, select, text, tuple_ +from sqlalchemy import and_, exists, or_, select, tuple_ from sqlalchemy.orm import Query, Session, defer from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound @@ -34,7 +32,6 @@ Edition, Identifier, IntegrationConfiguration, - IntegrationLibraryConfiguration, Library, LicensePool, LicensePoolDeliveryMechanism, @@ -50,22 +47,11 @@ production_session, ) from core.model.classification import Classification -from core.model.deferredtask import ( - DeferredTask, - DeferredTaskType, - InventoryReportTaskData, - start_next_task, -) from core.model.devicetokens import DeviceToken, DeviceTokenTypes from core.model.listeners import site_configuration_has_changed from core.model.patron import Loan from core.monitor import CollectionMonitor, ReaperMonitor -from core.opds_import import ( - OPDSAPI, - OPDSImporter, - OPDSImporterSettings, - OPDSImportMonitor, -) +from core.opds_import import OPDSAPI, OPDSImporter, OPDSImportMonitor from core.query.customlist import CustomListQueries from core.search.coverage_provider import SearchIndexCoverageProvider from core.search.coverage_remover 
import RemovesSearchCoverage @@ -2772,246 +2758,6 @@ def suppress_work(self, library: Library, identifier: Identifier) -> None: self._db.commit() -class GenerateInventoryReports(Script): - """Generate inventory reports from queued report tasks""" - - @classmethod - def arg_parser(cls, _db: Session | None) -> argparse.ArgumentParser: # type: ignore[override] - parser = argparse.ArgumentParser() - if _db is None: - raise ValueError("No database session provided.") - - return parser - - @classmethod - def parse_command_line( - cls, _db: Session | None = None, cmd_args: list[str] | None = None - ): - parser = cls.arg_parser(_db) - return parser.parse_known_args(cmd_args)[0] - - def do_run(self, cmd_args: list[str] | None = None) -> None: - parsed = self.parse_command_line(self._db, cmd_args=cmd_args) - - while True: - task = start_next_task(self._db, DeferredTaskType.INVENTORY_REPORT) - if not task: - break - - self.process_task(task) - - self.remove_old_tasks() - - def process_task(self, task: DeferredTask): - data = InventoryReportTaskData(**task.data) - files = [] - try: - current_time = datetime.datetime.now() - date_str = current_time.strftime("%Y-%m-%d_%H:%M:%s") - attachments = {} - - integrations = self._db.scalars( - select(IntegrationConfiguration) - .join(IntegrationLibraryConfiguration) - .where( - IntegrationLibraryConfiguration.library_id == data.library_id, - IntegrationConfiguration.goal == Goals.LICENSE_GOAL, - not_( - IntegrationConfiguration.settings_dict.contains( - {"include_in_inventory_report": False} - ) - ), - ) - ).all() - - registry = LicenseProvidersRegistry() - data_source_names = [] - for integration in integrations: - settings = registry[integration.protocol].settings_load(integration) - if not isinstance(settings, OPDSImporterSettings): - continue - data_source_names.append(settings.data_source) - - for data_source_name in data_source_names: - formatted_ds_name = data_source_name.lower().replace(" ", "_") - file_name = f"palace-inventory-report-{formatted_ds_name}-{date_str}" - # generate csv file - file_path = self.generate_report( - data_source_name=data_source_name, - library_id=data.library_id, - ) - # extract contents of files and prepare in a dictionary of email attachments - with open(file_path) as csv_file: - attachments[f"{file_name}.csv"] = csv_file.read() - files.append(csv_file) - - self.services.email.send_email( - subject=f"Inventory Report {current_time}", - receivers=[data.admin_email], - text="", - attachments=attachments, - ) - task.complete() - except Exception as e: - # log error - self.log.error(f"Failed to process task: {task}", e) - task.fail(failure_details=f"{e}") - finally: - self._db.commit() - for file in files: - os.remove(file.name) - - def generate_report(self, data_source_name: str, library_id: int) -> str: - """Generate a csv file and return the file path""" - with tempfile.NamedTemporaryFile( - "w", - delete=False, - ) as temp: - writer = csv.writer(temp, delimiter=",") - rows = self._db.execute( - text(self.inventory_report_query()), - {"library_id": library_id, "data_source_name": data_source_name}, - ) - writer.writerow(rows.keys()) - writer.writerows(rows) - return temp.name - - def inventory_report_query(self) -> str: - return """ - select - e.title, - e.author, - i.identifier, - e.language, - e.publisher, - e.medium as format, - ic.name collection_name, - DATE_PART('day', l.expires::date) - DATE_PART('day',lp.availability_time::date) as license_duration_days, - l.expires license_expiration_date, - l.checkouts_available 
initial_loan_count, - (l.checkouts_available-l.checkouts_left) consumed_loans, - l.checkouts_left remaining_loans, - l.terms_concurrency allowed_concurrent_users, - coalesce(lib_holds.active_hold_count, 0) library_active_hold_count, - coalesce(lib_loans.active_loan_count, 0) library_active_loan_count, - CASE WHEN collection_sharing.is_shared_collection THEN lp.patrons_in_hold_queue - ELSE -1 - END shared_active_hold_count, - CASE WHEN collection_sharing.is_shared_collection THEN lp.licenses_reserved - ELSE -1 - END shared_active_loan_count - from datasources d, - collections c, - integration_configurations ic, - integration_library_configurations il, - libraries lib, - editions e, - identifiers i, - (select ic.parent_id, - count(ic.parent_id) > 1 is_shared_collection - from integration_library_configurations ic, - integration_configurations i, - collections c - where c.integration_configuration_id = i.id and - i.id = ic.parent_id group by ic.parent_id) collection_sharing, - licensepools lp left outer join licenses l on lp.id = l.license_pool_id - left outer join (select h.license_pool_id, - p.library_id, - count(h.id) active_hold_count - from holds h, - patrons p, - libraries l - where p.id = h.patron_id and - p.library_id = l.id and - l.id = :library_id - group by p.library_id, h.license_pool_id) lib_holds on lp.id = lib_holds.license_pool_id - left outer join (select ln.license_pool_id, - p.library_id, - count(ln.id) active_loan_count - from loans ln, - patrons p, - libraries l - where p.id = ln.patron_id and - p.library_id = l.id and - l.id = :library_id - group by p.library_id, ln.license_pool_id) lib_loans on lp.id = lib_holds.license_pool_id - where lp.identifier_id = i.id and - e.primary_identifier_id = i.id and - d.id = e.data_source_id and - c.id = lp.collection_id and - c.integration_configuration_id = ic.id and - ic.id = il.parent_id and - ic.id = collection_sharing.parent_id and - il.library_id = lib.id and - d.name = :data_source_name and - lib.id = :library_id - order by title, author - """ - - def remove_old_tasks(self): - """Remove inventory generation tasks older than 30 days""" - self._db.query(DeferredTask) - thirty_days_ago = utc_now() - datetime.timedelta(days=30) - tasks = ( - self._db.query(DeferredTask) - .filter(DeferredTask.task_type == DeferredTaskType.INVENTORY_REPORT) - .filter(DeferredTask.processing_end_time < thirty_days_ago) - ) - for task in tasks: - self._db.delete(task) - - -class DeleteOldDeferredTasks(Script): - """Delete old deferred tasks.""" - - @classmethod - def arg_parser(cls, _db: Session | None) -> argparse.ArgumentParser: # type: ignore[override] - parser = argparse.ArgumentParser() - if _db is None: - raise ValueError("No database session provided.") - - return parser - - @classmethod - def parse_command_line( - cls, _db: Session | None = None, cmd_args: list[str] | None = None - ): - parser = cls.arg_parser(_db) - return parser.parse_known_args(cmd_args)[0] - - def do_run(self, cmd_args: list[str] | None = None) -> None: - parsed = self.parse_command_line(self._db, cmd_args=cmd_args) - self.remove_old_tasks() - - def remove_old_tasks(self): - """Remove inventory generation tasks older than 30 days""" - self._db.query(DeferredTask) - days = 30 - thirty_days_ago = utc_now() - datetime.timedelta(days=days) - tasks = ( - self._db.query(DeferredTask) - .filter(DeferredTask.task_type == DeferredTaskType.INVENTORY_REPORT) - .filter(DeferredTask.processing_end_time < thirty_days_ago) - ) - - tasks_removed = 0 - - for task in tasks: - 
self._db.delete(task) - tasks_removed += 1 - - self._db.commit() - if tasks_removed > 0: - self.log.info( - f"Successfully removed {tasks_removed} task{ 's' if tasks_removed > 1 else ''} " - f"that were completed over {days} days ago." - ) - else: - self.log.info( - f"There were no deferred tasks that were completed over {days} ago to be removed." - ) - - class MockStdin: """Mock a list of identifiers passed in on standard input.""" diff --git a/core/service/__init__.py b/core/service/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/service/analytics/__init__.py b/core/service/analytics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/service/celery/__init__.py b/core/service/celery/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/service/celery/celery.py b/core/service/celery/celery.py new file mode 100644 index 0000000000..ceeaa6dd0e --- /dev/null +++ b/core/service/celery/celery.py @@ -0,0 +1,31 @@ +from typing import Any + +from celery import Celery +from kombu import Exchange, Queue + + +def task_queue_config(cm_name: str) -> dict[str, Any]: + high_prefixed = f"{cm_name}:high" + default_prefixed = f"{cm_name}:default" + + return { + "task_queues": ( + Queue(high_prefixed, Exchange(high_prefixed), routing_key="high"), + Queue(default_prefixed, Exchange(default_prefixed), routing_key="default"), + ), + "task_default_queue": default_prefixed, + "task_default_exchange": default_prefixed, + "task_default_routing_key": "default", + } + + +def celery_factory(config: dict[str, Any]) -> Celery: + # Create a new Celery app + cm_name = config.get("cm_name") + assert isinstance(cm_name, str) + app = Celery(cm_name, task_cls="core.celery.task:Task") + app.conf.update(config) + app.conf.update(task_queue_config(cm_name)) + app.set_default() + + return app diff --git a/core/service/celery/configuration.py b/core/service/celery/configuration.py new file mode 100644 index 0000000000..5cc734f114 --- /dev/null +++ b/core/service/celery/configuration.py @@ -0,0 +1,28 @@ +from pydantic import RedisDsn + +from core.service.configuration import ServiceConfiguration + + +class CeleryConfiguration(ServiceConfiguration): + # All the settings here are named following the Celery configuration, so we can + # easily pass them into the Celery app. You can find more details about any of + # these settings in the Celery documentation. 
+    # https://docs.celeryq.dev/en/stable/userguide/configuration.html
+    broker_url: RedisDsn
+    broker_connection_retry_on_startup: bool = True
+
+    task_acks_late: bool = True
+    task_reject_on_worker_lost: bool = True
+    task_remote_tracebacks: bool = True
+    task_create_missing_queues: bool = False
+
+    worker_cancel_long_running_tasks_on_connection_loss: bool = False
+    worker_max_tasks_per_child: int = 100
+    worker_prefetch_multiplier: int = 1
+    worker_hijack_root_logger: bool = False
+    worker_log_color: bool = False
+
+    cm_name: str = "palace"
+
+    class Config:
+        env_prefix = "PALACE_CELERY_"
diff --git a/core/service/celery/container.py b/core/service/celery/container.py
new file mode 100644
index 0000000000..4abd02bb83
--- /dev/null
+++ b/core/service/celery/container.py
@@ -0,0 +1,11 @@
+from celery import Celery
+from dependency_injector import providers
+from dependency_injector.containers import DeclarativeContainer
+
+from core.service.celery.celery import celery_factory
+
+
+class CeleryContainer(DeclarativeContainer):
+    config = providers.Configuration()
+
+    app: providers.Provider[Celery] = providers.Resource(celery_factory, config=config)
diff --git a/core/service/container.py b/core/service/container.py
index 55f7e4a37c..8370ab9a8d 100644
--- a/core/service/container.py
+++ b/core/service/container.py
@@ -4,6 +4,8 @@
 from core.service.analytics.configuration import AnalyticsConfiguration
 from core.service.analytics.container import AnalyticsContainer
+from core.service.celery.configuration import CeleryConfiguration
+from core.service.celery.container import CeleryContainer
 from core.service.email.configuration import EmailConfiguration
 from core.service.email.container import Email
 from core.service.logging.configuration import LoggingConfiguration
@@ -44,6 +46,11 @@ class Services(DeclarativeContainer):
         config=config.email,
     )
 
+    celery = Container(
+        CeleryContainer,
+        config=config.celery,
+    )
+
 
 def wire_container(container: Services) -> None:
     container.wire(
@@ -75,6 +82,7 @@ def create_container() -> Services:
             "analytics": AnalyticsConfiguration().dict(),
             "search": SearchConfiguration().dict(),
             "email": EmailConfiguration().dict(),
+            "celery": CeleryConfiguration().dict(),
         }
     )
     wire_container(container)
diff --git a/core/service/email/__init__.py b/core/service/email/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/core/service/email/email.py b/core/service/email/email.py
index b7ad1e131b..2f11269767 100644
--- a/core/service/email/email.py
+++ b/core/service/email/email.py
@@ -1,4 +1,5 @@
 import os
+from collections.abc import Mapping
 from email.message import EmailMessage
 from typing import Any, Protocol
@@ -34,7 +35,7 @@ def send_email(
         receivers: list[str] | str,
         html: str | None = None,
         text: str | None = None,
-        attachments: dict[str, str | os.PathLike[Any] | bytes] | None = None,
+        attachments: Mapping[str, str | os.PathLike[Any] | bytes] | None = None,
     ) -> EmailMessage:
         return emailer.send(
             subject=subject,
@@ -42,7 +43,7 @@
             receivers=receivers,
             text=text,
             html=html,
-            attachments=attachments,
+            attachments=attachments,  # type: ignore[arg-type]
         )
 
 
@@ -54,6 +55,6 @@ def __call__(
         receivers: list[str] | str,
         html: str | None = None,
         text: str | None = None,
-        attachments: dict[str, str | os.PathLike[Any] | bytes] | None = None,
+        attachments: Mapping[str, str | os.PathLike[Any] | bytes] | None = None,
     ) -> EmailMessage:
         ...
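
CeleryConfiguration and CeleryContainer together mean that the Celery app is
assembled from environment variables at startup, the same way as the other
services in the container. A short sketch of how the wiring is exercised,
assuming the broker values that docker-compose.yml sets later in this patch:

    import os

    from core.service.container import container_instance

    # Any CeleryConfiguration field can be overridden through the
    # environment, using the PALACE_CELERY_ prefix from its Config class.
    os.environ["PALACE_CELERY_BROKER_URL"] = "redis://redis:6379/0"
    os.environ["PALACE_CELERY_CM_NAME"] = "palace"

    services = container_instance()
    services.init_resources()

    # celery_factory() has run at this point, so this is a configured
    # celery.Celery instance, set as the default app, with the
    # "palace:high" and "palace:default" queues declared.
    app = services.celery.app()
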
diff --git a/core/service/logging/__init__.py b/core/service/logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/service/logging/log.py b/core/service/logging/log.py index fdb03bbd34..f4befeae17 100644 --- a/core/service/logging/log.py +++ b/core/service/logging/log.py @@ -7,6 +7,7 @@ from logging import Handler from typing import TYPE_CHECKING, Any +from celery._state import get_current_task from watchtower import CloudWatchLogHandler from core.service.logging.configuration import LogLevel @@ -68,6 +69,14 @@ def ensure_str(s: Any) -> Any: ) if record.exc_info: data["traceback"] = self.formatException(record.exc_info) + + # Handle the case where we're running in a Celery task, this information is usually captured by + # the Celery log formatter, but we are not using that formatter for our code. + # See https://docs.celeryq.dev/en/stable/reference/celery.app.log.html#celery.app.log.TaskFormatter + if task := get_current_task(): + data["task_id"] = task.request.id + data["task_name"] = task.name + return json.dumps(data) diff --git a/core/service/search/__init__.py b/core/service/search/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/service/storage/__init__.py b/core/service/storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/tasks/__init__.py b/core/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/tasks/celery.py b/core/tasks/celery.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docker-compose.yml b/docker-compose.yml index 36c07522d5..124600e488 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,6 +17,8 @@ x-cm-variables: &cm PALACE_SECRET_KEY: "SECRET_KEY_USED_FOR_ADMIN_UI_COOKIES" PALACE_PATRON_WEB_HOSTNAMES: "*" PALACE_BASE_URL: "http://localhost:6500" + PALACE_CELERY_BROKER_URL: "redis://redis:6379/0" + PALACE_CELERY_CM_NAME: "test" # Set up the environment variables used for testing as well SIMPLIFIED_TEST_DATABASE: "postgresql://palace:test@pg:5432/circ" @@ -31,6 +33,8 @@ x-cm-variables: &cm condition: service_healthy os: condition: service_healthy + redis: + condition: service_healthy x-cm-build: &cm-build context: . @@ -93,8 +97,19 @@ services: environment: discovery.type: "single-node" DISABLE_SECURITY_PLUGIN: "true" + bootstrap.memory_lock: "true" + OPENSEARCH_JAVA_OPTS: "-Xms512m -Xmx512m" + DISABLE_INSTALL_DEMO_CONFIG: "true" healthcheck: test: curl --silent http://localhost:9200 >/dev/null; if [[ $$? == 52 ]]; then echo 0; else echo 1; fi interval: 30s timeout: 10s retries: 5 + + redis: + image: "redis:7" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 20s + retries: 3 diff --git a/docker/services/cron/cron.d/circulation b/docker/services/cron/cron.d/circulation index be0928cd70..46c6bd2f2c 100644 --- a/docker/services/cron/cron.d/circulation +++ b/docker/services/cron/cron.d/circulation @@ -112,7 +112,3 @@ HOME=/var/www/circulation 0 8,20 * * * root core/bin/run playtime_summation >> /var/log/cron.log 2>&1 # On the 2nd of every month 0 4 2 * * root core/bin/run playtime_reporting >> /var/log/cron.log 2>&1 - -# check the inventory report task queue every 15 minutes. 
-*/15 * * * * root core/bin/run generate_inventory_reports >> /var/log/cron.log 2>&1 -0 0 * * * root core/bin/run delete_old_deferred_tasks >> /var/log/cron.log 2>&1 diff --git a/poetry.lock b/poetry.lock index 004c2210b2..f22c3cf544 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "alembic" @@ -19,6 +19,31 @@ typing-extensions = ">=4" [package.extras] tz = ["backports.zoneinfo"] +[[package]] +name = "amqp" +version = "5.2.0" +description = "Low-level AMQP client for Python (fork of amqplib)." +optional = false +python-versions = ">=3.6" +files = [ + {file = "amqp-5.2.0-py3-none-any.whl", hash = "sha256:827cb12fb0baa892aad844fd95258143bce4027fdac4fccddbc43330fd281637"}, + {file = "amqp-5.2.0.tar.gz", hash = "sha256:a1ecff425ad063ad42a486c902807d1482311481c8ad95a72694b2975e75f7fd"}, +] + +[package.dependencies] +vine = ">=5.0.0,<6.0.0" + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + [[package]] name = "attrs" version = "23.1.0" @@ -124,6 +149,17 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "billiard" +version = "4.2.0" +description = "Python multiprocessing fork with improvements and bugfixes" +optional = false +python-versions = ">=3.7" +files = [ + {file = "billiard-4.2.0-py3-none-any.whl", hash = "sha256:07aa978b308f334ff8282bd4a746e681b3513db5c9a514cbdd810cbbdc19714d"}, + {file = "billiard-4.2.0.tar.gz", hash = "sha256:9a3c3184cb275aa17a732f93f65b20c525d3d9f253722d26a82194803ade5a2c"}, +] + [[package]] name = "blinker" version = "1.6.2" @@ -623,6 +659,63 @@ files = [ {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, ] +[[package]] +name = "celery" +version = "5.3.6" +description = "Distributed Task Queue." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "celery-5.3.6-py3-none-any.whl", hash = "sha256:9da4ea0118d232ce97dff5ed4974587fb1c0ff5c10042eb15278487cdd27d1af"}, + {file = "celery-5.3.6.tar.gz", hash = "sha256:870cc71d737c0200c397290d730344cc991d13a057534353d124c9380267aab9"}, +] + +[package.dependencies] +billiard = ">=4.2.0,<5.0" +click = ">=8.1.2,<9.0" +click-didyoumean = ">=0.3.0" +click-plugins = ">=1.1.1" +click-repl = ">=0.2.0" +kombu = ">=5.3.4,<6.0" +python-dateutil = ">=2.8.2" +redis = {version = ">=4.5.2,<4.5.5 || >4.5.5,<6.0.0", optional = true, markers = "extra == \"redis\""} +tblib = {version = ">=1.5.0", optional = true, markers = "python_version >= \"3.8.0\" and extra == \"tblib\""} +tzdata = ">=2022.7" +vine = ">=5.1.0,<6.0" + +[package.extras] +arangodb = ["pyArango (>=2.0.2)"] +auth = ["cryptography (==41.0.5)"] +azureblockblob = ["azure-storage-blob (>=12.15.0)"] +brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"] +cassandra = ["cassandra-driver (>=3.25.0,<4)"] +consul = ["python-consul2 (==0.1.5)"] +cosmosdbsql = ["pydocumentdb (==2.3.5)"] +couchbase = ["couchbase (>=3.0.0)"] +couchdb = ["pycouchdb (==1.14.2)"] +django = ["Django (>=2.2.28)"] +dynamodb = ["boto3 (>=1.26.143)"] +elasticsearch = ["elastic-transport (<=8.10.0)", "elasticsearch (<=8.11.0)"] +eventlet = ["eventlet (>=0.32.0)"] +gevent = ["gevent (>=1.5.0)"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +memcache = ["pylibmc (==1.6.3)"] +mongodb = ["pymongo[srv] (>=4.0.2)"] +msgpack = ["msgpack (==1.0.7)"] +pymemcache = ["python-memcached (==1.59)"] +pyro = ["pyro4 (==4.82)"] +pytest = ["pytest-celery (==0.0.0)"] +redis = ["redis (>=4.5.2,!=4.5.5,<6.0.0)"] +s3 = ["boto3 (>=1.26.143)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +solar = ["ephem (==4.1.5)"] +sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] +sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=1.3.1)"] +zstd = ["zstandard (==0.22.0)"] + [[package]] name = "certifi" version = "2024.2.2" @@ -746,6 +839,55 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "click-didyoumean" +version = "0.3.1" +description = "Enables git-like *did-you-mean* feature in click" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c"}, + {file = "click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463"}, +] + +[package.dependencies] +click = ">=7" + +[[package]] +name = "click-plugins" +version = "1.1.1" +description = "An extension module for click to enable registering CLI commands via setuptools entry-points." 
+optional = false +python-versions = "*" +files = [ + {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"}, + {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"}, +] + +[package.dependencies] +click = ">=4.0" + +[package.extras] +dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"] + +[[package]] +name = "click-repl" +version = "0.3.0" +description = "REPL plugin for Click" +optional = false +python-versions = ">=3.6" +files = [ + {file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"}, + {file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"}, +] + +[package.dependencies] +click = ">=7.0" +prompt-toolkit = ">=3.0.36" + +[package.extras] +testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] + [[package]] name = "colorama" version = "0.4.6" @@ -1878,6 +2020,38 @@ files = [ cryptography = ">=3.4" typing-extensions = ">=4.5.0" +[[package]] +name = "kombu" +version = "5.3.6" +description = "Messaging library for Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "kombu-5.3.6-py3-none-any.whl", hash = "sha256:49f1e62b12369045de2662f62cc584e7df83481a513db83b01f87b5b9785e378"}, + {file = "kombu-5.3.6.tar.gz", hash = "sha256:f3da5b570a147a5da8280180aa80b03807283d63ea5081fcdb510d18242431d9"}, +] + +[package.dependencies] +amqp = ">=5.1.1,<6.0.0" +vine = "*" + +[package.extras] +azureservicebus = ["azure-servicebus (>=7.10.0)"] +azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"] +confluentkafka = ["confluent-kafka (>=2.2.0)"] +consul = ["python-consul2"] +librabbitmq = ["librabbitmq (>=2.0.0)"] +mongodb = ["pymongo (>=4.1.1)"] +msgpack = ["msgpack"] +pyro = ["pyro4"] +qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] +redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] +slmq = ["softlayer-messaging (>=1.0.3)"] +sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] +sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +yaml = ["PyYAML (>=3.10)"] +zookeeper = ["kazoo (>=2.8.0)"] + [[package]] name = "levenshtein" version = "0.25.0" @@ -2761,6 +2935,20 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "prompt-toolkit" +version = "3.0.43" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.43-py3-none-any.whl", hash = "sha256:a11a29cb3bf0a28a387fe5122cdb649816a957cd9261dcedf8c9f1fef33eacf6"}, + {file = "prompt_toolkit-3.0.43.tar.gz", hash = "sha256:3527b7af26106cbc65a040bcc84839a3566ec1b051bb0bfe953631e704b0ff7d"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "proto-plus" version = "1.22.2" @@ -3329,6 +3517,20 @@ alembic = "*" pytest = ">=6.0" sqlalchemy = "*" +[[package]] +name = "pytest-celery" +version = "0.0.0" +description = "pytest-celery a shim pytest plugin to enable celery.contrib.pytest" +optional = false +python-versions = "*" +files = [ + {file = "pytest-celery-0.0.0.tar.gz", hash = "sha256:cfd060fc32676afa1e4f51b2938f903f7f75d952186b8c6cf631628c4088f406"}, + {file = "pytest_celery-0.0.0-py2.py3-none-any.whl", hash = "sha256:63dec132df3a839226ecb003ffdbb0c2cb88dd328550957e979c942766578060"}, +] + +[package.dependencies] +celery = ">=4.4.0" + 
[[package]] name = "pytest-cov" version = "5.0.0" @@ -3619,6 +3821,24 @@ files = [ [package.extras] full = ["numpy"] +[[package]] +name = "redis" +version = "5.0.3" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.3-py3-none-any.whl", hash = "sha256:5da9b8fe9e1254293756c16c008e8620b3d15fcc6dde6babde9541850e72a32d"}, + {file = "redis-5.0.3.tar.gz", hash = "sha256:4973bae7444c0fbed64a06b87446f79361cb7e4ec1538c022d696ed7a5015580"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "redmail" version = "0.6.0" @@ -4049,6 +4269,17 @@ files = [ [package.dependencies] typing-extensions = ">=3.7.4" +[[package]] +name = "tblib" +version = "3.0.0" +description = "Traceback serialization library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"}, + {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"}, +] + [[package]] name = "textblob" version = "0.18.0.post0" @@ -4325,6 +4556,17 @@ files = [ {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] +[[package]] +name = "tzdata" +version = "2024.1" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + [[package]] name = "unicodecsv" version = "0.14.1" @@ -4372,6 +4614,17 @@ files = [ {file = "uwsgi-2.0.24.tar.gz", hash = "sha256:77b6dd5cd633f4ae87ee393f7701f617736815499407376e78f3d16467523afe"}, ] +[[package]] +name = "vine" +version = "5.1.0" +description = "Python promises." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc"}, + {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, +] + [[package]] name = "virtualenv" version = "20.25.0" @@ -4419,6 +4672,17 @@ files = [ {file = "wcag-contrast-ratio-0.9.tar.gz", hash = "sha256:69192b8e5c0a7d0dc5ff1187eeb3e398141633a4bde51c69c87f58fe87ed361c"}, ] +[[package]] +name = "wcwidth" +version = "0.2.13" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, +] + [[package]] name = "websocket-client" version = "1.5.0" @@ -4561,4 +4825,4 @@ lxml = ">=3.8" [metadata] lock-version = "2.0" python-versions = ">=3.10,<4" -content-hash = "b7a9561e27c582de49b11c725abd5244efc3d1499c32a74ab19813935b0f55b6" +content-hash = "66490f35ee8f9a8cfac0908bca3737a73a34e92ed8131443e52c1a595cde2d45" diff --git a/pyproject.toml b/pyproject.toml index 4ff1a16100..4a1f131375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,6 @@ module = [ "core.migration.*", "core.model.announcements", "core.model.collection", - "core.model.deferredtask", "core.model.hassessioncache", "core.model.integration", "core.model.library", @@ -123,6 +122,26 @@ strict_equality = true warn_return_any = true warn_unused_ignores = true +[[tool.mypy.overrides]] +# Custom mypy overrides for the core.celery module. +# Since Celery is untyped, and we're using a lot of its +# features, we disable some of the strict mypy checks +# that get annoying when working with Celery. +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_subclassing_any = false +disallow_untyped_decorators = false +disallow_untyped_defs = true +module = [ + "core.celery.*", +] +no_implicit_reexport = true +strict_concatenate = true +strict_equality = true +warn_return_any = true +warn_unused_ignores = true + [[tool.mypy.overrides]] # This override silences errors for modules in our own codebase that we import # from other covered modules. 
Ideally we will be able to remove this override @@ -139,6 +158,7 @@ module = [ ignore_missing_imports = true module = [ "aws_xray_sdk.ext.*", + "celery.*", # This is ignored because the file is created when building a container # so it typically doesn't exist when running mypy, but since it only # contains a couple version strings it can be safely ignored @@ -154,6 +174,7 @@ module = [ "html_sanitizer", "isbnlib", "jwcrypto", + "kombu", "lxml.*", "money", "multipledispatch", @@ -182,6 +203,7 @@ alembic = {extras = ["tz"], version = "^1.8.1"} aws-xray-sdk = "~2.13" bcrypt = "^4.0.1" boto3 = "^1.28" +celery = {extras = ["redis", "tblib"], version = "^5.3.6"} certifi = "*" click = "^8.1.3" contextlib2 = "21.6.0" @@ -255,6 +277,7 @@ psycopg2-binary = "~2.9.5" pyfakefs = "^5.3" pytest = ">=7.2.0" pytest-alembic = "^0.11.0" +pytest-celery = "^0.0.0" pytest-cov = "^5.0.0" pytest-timeout = "*" requests-mock = "1.12.1" diff --git a/tests/api/admin/controller/test_collections.py b/tests/api/admin/controller/test_collections.py index db0a62474e..428f5944e4 100644 --- a/tests/api/admin/controller/test_collections.py +++ b/tests/api/admin/controller/test_collections.py @@ -1,5 +1,5 @@ import json -from unittest.mock import MagicMock, create_autospec +from unittest.mock import MagicMock, create_autospec, patch import flask import pytest @@ -706,7 +706,12 @@ def test_collection_delete( collection.integration_configuration.id, ) - with flask_app_fixture.test_request_context_system_admin("/", method="DELETE"): + with ( + flask_app_fixture.test_request_context_system_admin("/", method="DELETE"), + patch( + "api.admin.controller.collection_settings.collection_delete" + ) as mock_delete, + ): assert collection.integration_configuration.id is not None response = controller.process_delete( collection.integration_configuration.id @@ -721,6 +726,10 @@ def test_collection_delete( assert fetched_collection == collection assert fetched_collection.marked_for_deletion is True + # The controller called collection_delete with the correct arguments, to + # queue up the collection for deletion in the background. 
+ mock_delete.delay.assert_called_once_with(collection.id) + def test_collection_delete_cant_delete_parent( self, controller: CollectionSettingsController, diff --git a/tests/api/admin/controller/test_report.py b/tests/api/admin/controller/test_report.py index 72ff963c9f..ba6fecafad 100644 --- a/tests/api/admin/controller/test_report.py +++ b/tests/api/admin/controller/test_report.py @@ -1,7 +1,8 @@ -import json from http import HTTPStatus +from unittest.mock import patch import pytest +from flask import Response from core.model import create from core.model.admin import Admin, AdminRole @@ -28,23 +29,21 @@ def test_generate_inventory_report(self, report_fixture: ReportControllerFixture report_fixture.ctrl.library = report_fixture.ctrl.db.default_library() system_admin, _ = create(db.session, Admin, email="admin@email.com") system_admin.add_role(AdminRole.SYSTEM_ADMIN) - with report_fixture.request_context_with_library_and_admin( - f"/", - admin=system_admin, - ) as ctx: + with ( + report_fixture.request_context_with_library_and_admin( + f"/", + admin=system_admin, + ), + patch( + "api.admin.controller.report.generate_inventory_reports" + ) as mock_generate, + ): response = ctrl.generate_inventory_report() + mock_generate.delay.assert_called_once_with( + library_id=report_fixture.ctrl.library.id, + admin_email=system_admin.email, + ) + assert isinstance(response, Response) assert response.status_code == HTTPStatus.ACCEPTED - body = json.loads(response.data) # type: ignore - assert body and body["message"].__contains__("admin@email.com") - assert not body.__contains__("already") - - # check that when generating a duplicate request a 409 is returned. - with report_fixture.request_context_with_library_and_admin( - f"/", - admin=system_admin, - ) as ctx: - response = ctrl.generate_inventory_report() - body = json.loads(response.data) # type: ignore - assert response.status_code == HTTPStatus.CONFLICT - assert body and body["message"].__contains__("admin@email.com") - assert body["message"].__contains__("already") + assert isinstance(response.json, dict) + assert system_admin.email in response.json["message"] diff --git a/tests/core/celery/__init__.py b/tests/core/celery/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/celery/tasks/__init__.py b/tests/core/celery/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/celery/tasks/test_collection_delete.py b/tests/core/celery/tasks/test_collection_delete.py new file mode 100644 index 0000000000..9c4a02b994 --- /dev/null +++ b/tests/core/celery/tasks/test_collection_delete.py @@ -0,0 +1,61 @@ +from logging import INFO + +from _pytest.logging import LogCaptureFixture +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.celery.tasks.collection_delete import CollectionDeleteJob, collection_delete +from core.model import Collection +from tests.fixtures.celery import CeleryFixture +from tests.fixtures.database import DatabaseTransactionFixture + + +def test_collection_delete_job_collection(db: DatabaseTransactionFixture): + # A non-existent collection should return None + assert CollectionDeleteJob.collection(db.session, 1) is None + + collection = db.collection(name="collection1") + assert collection.id is not None + assert CollectionDeleteJob.collection(db.session, collection.id) == collection + + +def test_collection_delete_job_collection_name(db: DatabaseTransactionFixture): + collection = db.collection(name="collection1") + assert ( + 
CollectionDeleteJob.collection_name(collection) + == f"{collection.name}/{collection.protocol} ({collection.id})" + ) + + +def test_collection_delete_job_run( + db: DatabaseTransactionFixture, + mock_session_maker: sessionmaker, + caplog: LogCaptureFixture, +): + # A non-existent collection should log an error + caplog.set_level(INFO) + CollectionDeleteJob(mock_session_maker, 1).run() + assert "Collection with id 1 not found. Unable to delete." in caplog.text + + collection = db.collection(name="collection1") + collection.marked_for_deletion = True + query = select(Collection).where(Collection.id == collection.id) + + assert db.session.execute(query).scalar_one_or_none() == collection + + assert collection.id is not None + job = CollectionDeleteJob(mock_session_maker, collection.id) + job.run() + assert db.session.execute(query).scalar_one_or_none() is None + assert f"Deleting collection" in caplog.text + + +def test_collection_delete_task( + db: DatabaseTransactionFixture, celery_fixture: CeleryFixture +): + collection = db.collection(name="collection1") + collection.marked_for_deletion = True + query = select(Collection).where(Collection.id == collection.id) + assert db.session.execute(query).scalar_one_or_none() == collection + collection_delete.delay(collection.id).wait() + assert db.session.execute(query).scalar_one_or_none() is None diff --git a/tests/core/celery/tasks/test_inventory_reports.py b/tests/core/celery/tasks/test_inventory_reports.py new file mode 100644 index 0000000000..5b67707047 --- /dev/null +++ b/tests/core/celery/tasks/test_inventory_reports.py @@ -0,0 +1,142 @@ +import csv +from io import StringIO +from unittest.mock import MagicMock, create_autospec + +from _pytest.monkeypatch import MonkeyPatch +from sqlalchemy.orm import sessionmaker + +from core.celery.tasks.inventory_reports import ( + InventoryReportsJob, + generate_inventory_reports, +) +from core.opds_import import OPDSImporterSettings +from tests.fixtures.celery import CeleryFixture +from tests.fixtures.database import DatabaseTransactionFixture + + +def test_generate_inventory_reports_job( + db: DatabaseTransactionFixture, mock_session_maker: sessionmaker +): + # create some test data that we expect to be picked up in the inventory report + library = db.library(short_name="test") + settings = OPDSImporterSettings( + include_in_inventory_report=True, + external_account_id="http://opds.com", + data_source="BiblioBoard", + ) + collection = db.collection( + name="BiblioBoard Test Collection", settings=settings.dict() + ) + collection.libraries = [library] + + # Configure test data we expect will not be picked up. 
+ no_inventory_report_settings = OPDSImporterSettings( + include_in_inventory_report=False, + external_account_id="http://opds.com", + data_source="AnotherOpdsDataSource", + ) + collection_not_to_include = db.collection( + name="Another Test Collection", settings=no_inventory_report_settings.dict() + ) + collection_not_to_include.libraries = [library] + + ds = collection.data_source + assert ds + title = "Leaves of Grass" + author = "Walt Whitman" + email = "test@email.com" + checkouts_left = 10 + checkouts_available = 11 + terms_concurrency = 5 + edition = db.edition(data_source_name=ds.name) + edition.title = title + edition.author = author + db.work( + language="eng", + fiction=True, + with_license_pool=False, + data_source_name=ds.name, + presentation_edition=edition, + collection=collection, + ) + licensepool = db.licensepool( + edition=edition, + open_access=False, + data_source_name=ds.name, + set_edition_as_presentation=True, + collection=collection, + ) + + db.license( + pool=licensepool, + checkouts_available=checkouts_available, + checkouts_left=checkouts_left, + terms_concurrency=terms_concurrency, + ) + + assert library.id + + send_email_mock = MagicMock() + InventoryReportsJob(mock_session_maker, send_email_mock, library.id, email).run() + + send_email_mock.assert_called_once() + assert send_email_mock.call_args.kwargs["receivers"] == [email] + assert "Inventory Report" in send_email_mock.call_args.kwargs["subject"] + attachments = send_email_mock.call_args.kwargs["attachments"] + + assert len(attachments) == 1 + key, value = next(iter(attachments.items())) + assert "biblioboard" in key + assert len(value) > 0 + csv_file = StringIO(value) + reader = csv.DictReader(csv_file, delimiter=",") + + assert reader.fieldnames == [ + "title", + "author", + "identifier", + "language", + "publisher", + "format", + "collection_name", + "license_duration_days", + "license_expiration_date", + "initial_loan_count", + "consumed_loans", + "remaining_loans", + "allowed_concurrent_users", + "library_active_hold_count", + "library_active_loan_count", + "shared_active_hold_count", + "shared_active_loan_count", + ] + + rows = list(reader) + assert len(rows) == 1 + row = rows[0] + + assert row["title"] == title + assert row["author"] == author + assert row["shared_active_hold_count"] == "-1" + assert row["shared_active_loan_count"] == "-1" + assert row["initial_loan_count"] == str(checkouts_available) + assert row["consumed_loans"] == str(checkouts_available - checkouts_left) + assert row["allowed_concurrent_users"] == str(terms_concurrency) + + +def test_generate_inventory_reports_task( + db: DatabaseTransactionFixture, + celery_fixture: CeleryFixture, + monkeypatch: MonkeyPatch, +): + mock_job = create_autospec(generate_inventory_reports) + monkeypatch.setattr( + "core.celery.tasks.inventory_reports.InventoryReportsJob", + mock_job, + ) + + generate_inventory_reports.delay(library_id=1, admin_email="test@test.com").wait() + mock_job.assert_called_once() + assert mock_job.call_args.kwargs["library_id"] == 1 + assert mock_job.call_args.kwargs["admin_email"] == "test@test.com" + mock_job.return_value.run.assert_called_once_with() diff --git a/tests/core/conftest.py b/tests/core/conftest.py index bb6feff578..dfb1327774 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -2,6 +2,7 @@ pytest_plugins = [ "tests.fixtures.announcements", + "tests.fixtures.celery", "tests.fixtures.csv_files", "tests.fixtures.database", "tests.fixtures.library", diff --git 
a/tests/core/service/logging/test_log.py b/tests/core/service/logging/test_log.py index 2a3b6016d5..a6a617f741 100644 --- a/tests/core/service/logging/test_log.py +++ b/tests/core/service/logging/test_log.py @@ -116,6 +116,27 @@ def test_format_args( # The resulting data is always a Unicode string. assert data["message"] == expected + def test_format_celery_worker(self, log_record: LogRecordCallable): + with patch( + "core.service.logging.log.get_current_task" + ) as mock_get_current_task: + formatter = JSONFormatter() + record = log_record() + + # if we are in a celery task, we should include the task ID and task name + mock_get_current_task.return_value.configure_mock( + **{"name": "task_name", "request.id": "task_id"} + ) + data = json.loads(formatter.format(record)) + assert data["task_id"] == "task_id" + assert data["task_name"] == "task_name" + + # otherwise, they should not be included + mock_get_current_task.return_value = None + data = json.loads(formatter.format(record)) + assert "task_id" not in data + assert "task_name" not in data + class TestLogLoopPreventionFilter: @pytest.mark.parametrize( diff --git a/tests/core/test_app_server.py b/tests/core/test_app_server.py index 7200885335..b1642cf0a7 100644 --- a/tests/core/test_app_server.py +++ b/tests/core/test_app_server.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Iterable from functools import partial from io import BytesIO -from unittest.mock import MagicMock, PropertyMock +from unittest.mock import MagicMock, PropertyMock, patch import flask import pytest @@ -129,15 +129,11 @@ def test_process_urns_hook_method( urn_lookup_handler_fixture.transaction.session, ) - # Verify that process_urns() calls post_lookup_hook() once - # it's done. - class Mock(URNLookupHandler): - def post_lookup_hook(self): - self.called = True + with patch.object(URNLookupHandler, "post_lookup_hook") as mock_hook: + handler = URNLookupHandler(session) + handler.process_urns([]) - handler = Mock(session) - handler.process_urns([]) - assert True == handler.called + mock_hook.assert_called_once() def test_process_urns_invalid_urn( self, urn_lookup_handler_fixture: URNLookupHandlerFixture @@ -537,6 +533,7 @@ def test_unhandled_error(self, error_handler_fixture: ErrorHandlerFixture): def test_handle_error_as_problem_detail_document( self, error_handler_fixture: ErrorHandlerFixture, caplog: LogCaptureFixture ): + caplog.set_level(LogLevel.error.value) handler = error_handler_fixture.handler() with error_handler_fixture.app.test_request_context("/"): try: @@ -562,6 +559,7 @@ def test_handle_error_as_problem_detail_document( def test_handle_error_problem_error( self, error_handler_fixture: ErrorHandlerFixture, caplog: LogCaptureFixture ): + caplog.set_level(LogLevel.warning.value) handler = error_handler_fixture.handler() with error_handler_fixture.app.test_request_context("/"): try: @@ -587,6 +585,7 @@ def test_handle_error_problem_error( def test_handle_operational_error( self, error_handler_fixture, caplog: LogCaptureFixture ): + caplog.set_level(LogLevel.warning.value) handler = error_handler_fixture.handler() with error_handler_fixture.app.test_request_context("/"): try: diff --git a/tests/core/test_opds2_import.py b/tests/core/test_opds2_import.py index 33e07d16b6..23d9fdcd79 100644 --- a/tests/core/test_opds2_import.py +++ b/tests/core/test_opds2_import.py @@ -1,5 +1,6 @@ import datetime from collections.abc import Generator +from logging import WARNING from unittest.mock import MagicMock, patch import pytest @@ -628,6 +629,7 @@ def 
test_token_fulfill_no_content_link( self, opds2_api_fixture: Opds2ApiFixture, caplog: LogCaptureFixture ): # No content_link on the fulfillment info coming into the function + caplog.set_level(WARNING) mock = MagicMock(spec=FulfillmentInfo) mock.content_link = None response = opds2_api_fixture.api.fulfill_token_auth( diff --git a/tests/core/test_scripts.py b/tests/core/test_scripts.py index de7654ab0f..90547d120b 100644 --- a/tests/core/test_scripts.py +++ b/tests/core/test_scripts.py @@ -1,11 +1,8 @@ from __future__ import annotations -import csv import datetime import json -import logging import random -from dataclasses import asdict from io import StringIO from unittest.mock import MagicMock, call, create_autospec, patch @@ -40,18 +37,10 @@ ) from core.model.classification import Classification, Subject from core.model.customlist import CustomList -from core.model.deferredtask import ( - DeferredTask, - DeferredTaskStatus, - DeferredTaskType, - InventoryReportTaskData, - queue_task, - start_next_task, -) from core.model.devicetokens import DeviceToken, DeviceTokenTypes from core.model.patron import Patron from core.monitor import CollectionMonitor, Monitor, ReaperMonitor -from core.opds_import import OPDSAPI, OPDSImporterSettings, OPDSImportMonitor +from core.opds_import import OPDSAPI, OPDSImportMonitor from core.scripts import ( AddClassificationScript, CheckContributorNamesInDB, @@ -62,9 +51,7 @@ ConfigureLibraryScript, CustomListUpdateEntriesScript, DeleteInvisibleLanesScript, - DeleteOldDeferredTasks, Explain, - GenerateInventoryReports, IdentifierInputScript, LaneSweeperScript, LibraryInputScript, @@ -2577,194 +2564,6 @@ def test_suppress_work(self, db: DatabaseTransactionFixture): assert work.suppressed_for == [test_library] -class TestGenerateInventoryReports: - def test_do_run(self, db: DatabaseTransactionFixture): - # create some test data that we expect to be picked up in the inventory report - library = db.library(short_name="test") - settings = OPDSImporterSettings( - include_in_inventory_report=True, - external_account_id="http://opds.com", - data_source="BiblioBoard", - ) - collection = db.collection( - name="BiblioBoard Test Collection", settings=settings.dict() - ) - collection.libraries = [library] - - # Configure test data we expect will not be picked up. 
- no_inventory_report_settings = OPDSImporterSettings( - include_in_inventory_report=False, - external_account_id="http://opds.com", - data_source="AnotherOpdsDataSource", - ) - collection_not_to_include = db.collection( - name="Another Test Collection", settings=no_inventory_report_settings.dict() - ) - collection_not_to_include.libraries = [library] - - ds = collection.data_source - assert ds - title = "Leaves of Grass" - author = "Walt Whitman" - email = "test@email.com" - checkouts_left = 10 - checkouts_available = 11 - terms_concurrency = 5 - edition = db.edition(data_source_name=ds.name) - edition.title = title - edition.author = author - db.work( - language="eng", - fiction=True, - with_license_pool=False, - data_source_name=ds.name, - presentation_edition=edition, - collection=collection, - ) - licensepool = db.licensepool( - edition=edition, - open_access=False, - data_source_name=ds.name, - set_edition_as_presentation=True, - collection=collection, - ) - - db.license( - pool=licensepool, - checkouts_available=checkouts_available, - checkouts_left=checkouts_left, - terms_concurrency=terms_concurrency, - ) - - assert library.id - data = InventoryReportTaskData( - admin_id=1, library_id=library.id, admin_email=email - ) - task, is_new = queue_task( - db.session, task_type=DeferredTaskType.INVENTORY_REPORT, data=asdict(data) - ) - - assert task.status == DeferredTaskStatus.READY - - script = GenerateInventoryReports(db.session) - send_email_mock = create_autospec(script.services.email.container.send_email) - script.services.email.container.send_email = send_email_mock - script.do_run() - send_email_mock.assert_called_once() - args, kwargs = send_email_mock.call_args - assert task.status == DeferredTaskStatus.SUCCESS - assert kwargs["receivers"] == [email] # type:ignore[unreachable] - assert kwargs["subject"].__contains__("Inventory Report") - attachments: dict = kwargs["attachments"] - - assert len(attachments) == 1 - key = [*attachments.keys()][0] - assert "biblioboard" in key - value = attachments[key] - assert len(value) > 0 - csv_file = StringIO(value) - reader = csv.reader(csv_file, delimiter=",") - first_row = None - row_count = 0 - - for row in reader: - row_count += 1 - if not first_row: - first_row = row - row_headers = [ - "title", - "author", - "identifier", - "language", - "publisher", - "format", - "collection_name", - "license_duration_days", - "license_expiration_date", - "initial_loan_count", - "consumed_loans", - "remaining_loans", - "allowed_concurrent_users", - "library_active_hold_count", - "library_active_loan_count", - "shared_active_hold_count", - "shared_active_loan_count", - ] - for h in row_headers: - assert h in row - continue - - assert row[first_row.index("title")] == title - assert row[first_row.index("author")] == author - assert row[first_row.index("shared_active_hold_count")] == "-1" - assert row[first_row.index("shared_active_loan_count")] == "-1" - assert row[first_row.index("initial_loan_count")] == str( - checkouts_available - ) - assert row[first_row.index("consumed_loans")] == str( - checkouts_available - checkouts_left - ) - assert row[first_row.index("allowed_concurrent_users")] == str( - terms_concurrency - ) - - assert row_count == 2 - - -class TestDeleteOldDeferredTasks: - def test_do_run( - self, db: DatabaseTransactionFixture, caplog: pytest.LogCaptureFixture - ): - caplog.set_level(logging.INFO) - # create some test data - _db = db.session - - # the deferred task table should be empty - task = _db.query(DeferredTask).first() - assert not 
task - - data = InventoryReportTaskData( - admin_id=1, library_id=1, admin_email="test@email.com" - ) - task, is_new = queue_task( - db.session, task_type=DeferredTaskType.INVENTORY_REPORT, data=asdict(data) - ) - - assert task - assert is_new - - task2 = start_next_task(_db, task_type=DeferredTaskType.INVENTORY_REPORT) - - assert task2 - # sanity check: make sure we got back the task we just created. - assert task2.id == task.id - assert task2.processing_start_time - # make sure it is processing - assert task2.status == DeferredTaskStatus.PROCESSING - task2.complete() - assert task2.processing_end_time - _db.commit() - - # run it with the expection of no tasks to be deleted. - DeleteOldDeferredTasks(_db).do_run() - assert caplog.messages[0].__contains__("There were no deferred tasks") - - # set the task's end processing time to 30 days ago. - task3 = _db.query(DeferredTask).filter(DeferredTask.id == task2.id).first() - assert task3 - task3.processing_end_time = task3.processing_end_time - datetime.timedelta( - days=30 - ) - _db.commit() - # run again with the expectation of 1 task removed. - DeleteOldDeferredTasks(_db).do_run() - task4 = _db.query(DeferredTask).first() - assert not task4 - assert caplog.messages[1].__contains__( - "Successfully removed 1 task that were completed over 30 days ago." - ) - - class TestWorkConsolidationScript: """TODO""" diff --git a/tests/fixtures/celery.py b/tests/fixtures/celery.py new file mode 100644 index 0000000000..4251c60bdf --- /dev/null +++ b/tests/fixtures/celery.py @@ -0,0 +1,97 @@ +import os +from collections.abc import Generator, Mapping +from dataclasses import dataclass +from typing import Any +from unittest.mock import PropertyMock, patch + +import pytest +from celery import Celery +from celery.worker import WorkController + +from core.celery.task import Task +from core.service.celery.celery import task_queue_config +from core.service.celery.configuration import CeleryConfiguration +from core.service.celery.container import CeleryContainer +from tests.fixtures.database import MockSessionMaker +from tests.fixtures.services import ServicesFixture + + +@pytest.fixture(scope="session") +def celery_worker_parameters() -> Mapping[str, Any]: + """ + Change the init parameters of Celery workers. + + Normally when testing, we want to make sure that if there is an issue with the task + the worker will shut down after a certain amount of time. We default this to 30 sec. + However, when debugging it can be useful to set this to None, so you can set breakpoints + in the worker code, without the worker timing out and shutting down. + """ + timeout = os.environ.get( + "PALACE_TEST_CELERY_WORKER_SHUTDOWN_TIMEOUT", "30.0" + ).lower() + shutdown_timeout = None if timeout == "none" or timeout == "" else float(timeout) + return {"shutdown_timeout": shutdown_timeout} + + +@pytest.fixture(scope="session") +def celery_pydantic_config() -> CeleryConfiguration: + """Configure the test Celery app. + + The config returned will then be used to configure the `celery_app` fixture. + """ + return CeleryConfiguration.construct(broker_url="memory://", cm_name="test") # type: ignore[arg-type] + + +@pytest.fixture(scope="session") +def celery_config(celery_pydantic_config: CeleryConfiguration) -> Mapping[str, Any]: + """Configure the test Celery app. + + The config returned will then be used to configure the `celery_app` fixture. 
+ """
+ cm_name = celery_pydantic_config.cm_name
+ return celery_pydantic_config.dict() | task_queue_config(cm_name)
+
+
+@pytest.fixture(scope="session")
+def celery_parameters() -> Mapping[str, Any]:
+ """Change the init parameters of the test Celery app.
+
+ The dict returned will be used as parameters when instantiating `~celery.Celery`.
+ """
+ return {"task_cls": "core.celery.task:Task"}
+
+
+@dataclass
+class CeleryFixture:
+ container: CeleryContainer
+ app: Celery
+ config: CeleryConfiguration
+ worker: WorkController
+
+
+@pytest.fixture()
+def celery_fixture(
+ services_fixture: ServicesFixture,
+ mock_session_maker: MockSessionMaker,
+ celery_app: Celery,
+ celery_worker: WorkController,
+ celery_pydantic_config: CeleryConfiguration,
+) -> Generator[CeleryFixture, None, None]:
+ """Fixture to provide a Celery app and worker for testing."""
+
+ # Make sure our services container has the correct celery app setup
+ container = services_fixture.celery_fixture.celery_container
+ container.config.from_dict(celery_pydantic_config.dict())
+ container.app.override(celery_app)
+
+ # We mock out the session maker so that it doesn't try to create a new
+ # session; instead, it uses the same session as the test transaction.
+ with (
+ patch.object(Task, "_session_maker", mock_session_maker),
+ patch.object(
+ Task, "services", PropertyMock(return_value=services_fixture.services)
+ ),
+ ):
+ yield CeleryFixture(
+ container, celery_app, celery_pydantic_config, celery_worker
+ )
diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py
index 1dd8367080..73094a9147 100644
--- a/tests/fixtures/database.py
+++ b/tests/fixtures/database.py
@@ -8,6 +8,7 @@ import time
 import uuid
 from collections.abc import Generator, Iterable
+from contextlib import contextmanager
 from textwrap import dedent
 from typing import Any
@@ -16,7 +17,7 @@
 from Crypto.PublicKey.RSA import import_key
 from sqlalchemy import MetaData
 from sqlalchemy.engine import Connection, Engine, Transaction
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
 import core.lane
 from api.discovery.opds_registration import OpdsRegistrationService
@@ -1000,6 +1001,24 @@ def create_integration_library_configuration(
 return fixture
+class MockSessionMaker:
+ def __init__(self, session: Session):
+ self._session = session
+
+ def __call__(self) -> Session:
+ return self._session
+
+ @contextmanager
+ def begin(self) -> Generator[Session, None, None]:
+ with self._session.begin_nested():
+ yield self._session
+
+
+@pytest.fixture
+def mock_session_maker(db: DatabaseTransactionFixture) -> sessionmaker[Session]:
+ return MockSessionMaker(db.session) # type: ignore[return-value]
+
+
 class DBStatementCounter:
 """
 Use as a context manager to count the number of execute()'s performed
diff --git a/tests/fixtures/services.py b/tests/fixtures/services.py
index dc075b304b..5ab9513b42 100644
--- a/tests/fixtures/services.py
+++ b/tests/fixtures/services.py
@@ -8,12 +8,14 @@
 import boto3
 import pytest
+from celery import Celery
 from core.analytics import Analytics
 from core.external_search import ExternalSearchIndex
 from core.search.revision_directory import SearchRevisionDirectory
 from core.search.service import SearchServiceOpensearch1
 from core.service.analytics.container import AnalyticsContainer
+from core.service.celery.container import CeleryContainer
 from core.service.container import Services, wire_container
 from core.service.email.configuration import EmailConfiguration
 from core.service.email.container 
import Email @@ -128,6 +130,20 @@ def services_email_fixture() -> ServicesEmailFixture: return ServicesEmailFixture(email_container, mock_emailer, sender_email) +@dataclass +class ServicesCeleryFixture: + celery_container: CeleryContainer + app: Celery + + +@pytest.fixture +def services_celery_fixture() -> ServicesCeleryFixture: + celery_container = CeleryContainer() + celery_mock_app = MagicMock() + celery_container.app.override(celery_mock_app) + return ServicesCeleryFixture(celery_container, celery_mock_app) + + class ServicesFixture: """ Provide a real services container, with all services mocked out. @@ -140,12 +156,14 @@ def __init__( search: ServicesSearchFixture, analytics: ServicesAnalyticsFixture, email: ServicesEmailFixture, + celery: ServicesCeleryFixture, ) -> None: self.logging_fixture = logging self.storage_fixture = storage self.search_fixture = search self.analytics_fixture = analytics self.email_fixture = email + self.celery_fixture = celery self.services = Services() self.services.logging.override(logging.logging_container) @@ -153,6 +171,7 @@ def __init__( self.services.search.override(search.search_container) self.services.analytics.override(analytics.analytics_container) self.services.email.override(email.email_container) + self.services.celery.override(celery.celery_container) # setup basic configuration from default settings self.services.config.from_dict({"sitewide": SitewideConfiguration().dict()}) @@ -189,6 +208,7 @@ def services_fixture( services_search_fixture: ServicesSearchFixture, services_analytics_fixture: ServicesAnalyticsFixture, services_email_fixture: ServicesEmailFixture, + services_celery_fixture: ServicesCeleryFixture, ) -> Generator[ServicesFixture, None, None]: fixture = ServicesFixture( logging=services_logging_fixture, @@ -196,6 +216,7 @@ def services_fixture( search=services_search_fixture, analytics=services_analytics_fixture, email=services_email_fixture, + celery=services_celery_fixture, ) with mock_services_container(fixture.services): yield fixture
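
Reviewer note (illustrative only, not part of the patch): taken together, `celery_fixture` and `mock_session_maker` let a test run a task on a real embedded worker over the `memory://` broker while all of the task's database work stays inside the test transaction. Below is a minimal sketch of the resulting test pattern, modeled on test_collection_delete_task above; `my_task` is a hypothetical task built on the custom core.celery.task:Task base class, not something this patch defines. When stepping through a task in a debugger, the worker shutdown timeout can be disabled by setting PALACE_TEST_CELERY_WORKER_SHUTDOWN_TIMEOUT=none, per the celery_worker_parameters docstring.

    from tests.fixtures.celery import CeleryFixture
    from tests.fixtures.database import DatabaseTransactionFixture


    def test_my_task(
        db: DatabaseTransactionFixture, celery_fixture: CeleryFixture
    ) -> None:
        # celery_fixture patches Task._session_maker and Task.services, so the
        # task shares the session owned by `db` and sees the mocked-out
        # services container.
        result = my_task.delay()  # my_task: hypothetical task, for illustration
        result.wait()  # block until the embedded worker has executed the task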
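
Reviewer note (illustrative only, not part of the patch): MockSessionMaker stands in for the worker's SQLAlchemy sessionmaker. Calling it returns the test session itself, and begin() opens a nested transaction (a SAVEPOINT) on that session, so anything a job writes is rolled back along with the test transaction. A short sketch of that contract, using only names added in this patch:

    from sqlalchemy.orm import Session

    from tests.fixtures.database import MockSessionMaker


    def show_mock_session_maker_contract(session: Session) -> None:
        maker = MockSessionMaker(session)
        assert maker() is session  # __call__ hands back the test session
        with maker.begin() as nested:  # begin() wraps it in begin_nested()
            assert nested is session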