Skip to content

Commit

Permalink
[PP-1216] Zip email reports (#1818)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbernstein authored May 3, 2024
1 parent c733da5 commit f3ad4c7
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 127 deletions.
171 changes: 96 additions & 75 deletions src/palace/manager/celery/tasks/generate_inventory_and_hold_reports.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import csv
import os
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import IO, Any

from celery import shared_task
from sqlalchemy import not_, select, text
Expand Down Expand Up @@ -54,90 +54,111 @@ def run(self) -> None:
)
return

try:
current_time = datetime.now()
date_str = current_time.strftime("%Y-%m-%d_%H:%M:%s")
attachments: dict[str, Path] = {}

file_name_modifier = f"{library.short_name}-{date_str}"
self.log.info(
f"Starting inventory and holds report job for {library.name}({library.short_name})."
)

# resolve integrations
integrations = session.scalars(
select(IntegrationConfiguration)
.join(IntegrationLibraryConfiguration)
.where(
IntegrationLibraryConfiguration.library_id == self.library_id,
IntegrationConfiguration.goal == Goals.LICENSE_GOAL,
not_(
IntegrationConfiguration.settings_dict.contains(
{"include_in_inventory_report": False}
)
),
)
).all()
registry = LicenseProvidersRegistry()
integration_ids: list[int] = []
for integration in integrations:
settings = registry[integration.protocol].settings_load(integration)
if not isinstance(settings, OPDSImporterSettings):
continue
integration_ids.append(integration.id)
current_time = datetime.now()
date_str = current_time.strftime("%Y-%m-%d_%H:%M:%s")

# generate inventory report csv file
sql_params: dict[str, Any] = {
"library_id": library.id,
"integration_ids": tuple(integration_ids),
}
file_name_modifier = f"{library.short_name}-{date_str}"

inventory_report_file_path = self.generate_inventory_report(
session, sql_params=sql_params
# resolve integrations
integrations = session.scalars(
select(IntegrationConfiguration)
.join(IntegrationLibraryConfiguration)
.where(
IntegrationLibraryConfiguration.library_id == self.library_id,
IntegrationConfiguration.goal == Goals.LICENSE_GOAL,
not_(
IntegrationConfiguration.settings_dict.contains(
{"include_in_inventory_report": False}
)
),
)
).all()
registry = LicenseProvidersRegistry()
integration_ids: list[int] = []
for integration in integrations:
settings = registry[integration.protocol].settings_load(integration)
if not isinstance(settings, OPDSImporterSettings):
continue
integration_ids.append(integration.id)

# generate holds report csv file
holds_report_file_path = self.generate_holds_report(
session, sql_params=sql_params
)
# generate inventory report csv file
sql_params: dict[str, Any] = {
"library_id": library.id,
"integration_ids": tuple(integration_ids),
}

attachments[f"palace-inventory-report-{file_name_modifier}.csv"] = Path(
inventory_report_file_path
)
attachments[f"palace-holds-report-{file_name_modifier}.csv"] = Path(
holds_report_file_path
)
with tempfile.NamedTemporaryFile(
delete=self.delete_attachments
) as report_zip:
zip_path = Path(report_zip.name)

self.send_email(
subject=f"Inventory and Holds Reports {current_time}",
receivers=[self.email_address],
text="",
attachments=attachments,
)
finally:
if self.delete_attachments:
for file_path in attachments.values():
os.remove(file_path)
with (
self.create_temp_file() as inventory_report_file,
self.create_temp_file() as holds_report_file,
):
self.generate_csv_report(
session,
csv_file=inventory_report_file,
sql_params=sql_params,
query=self.inventory_report_query(),
)

def generate_inventory_report(
self, _db: Session, sql_params: dict[str, Any]
) -> str:
"""Generate an inventory csv file and return the file path"""
return self.generate_csv_report(_db, sql_params, self.inventory_report_query())
self.generate_csv_report(
session,
csv_file=holds_report_file,
sql_params=sql_params,
query=self.holds_report_query(),
)

with zipfile.ZipFile(
zip_path, "w", zipfile.ZIP_DEFLATED
) as archive:
archive.write(
filename=holds_report_file.name,
arcname=f"palace-holds-report-for-library-{file_name_modifier}.csv",
)
archive.write(
filename=inventory_report_file.name,
arcname=f"palace-inventory-report-for-library-{file_name_modifier}.csv",
)

self.send_email(
subject=f"Inventory and Holds Reports {current_time}",
receivers=[self.email_address],
text="",
attachments={
f"palace-inventory-and-holds-reports-for-{file_name_modifier}.zip": zip_path
},
)

def generate_holds_report(self, _db: Session, sql_params: dict[str, Any]) -> str:
"""Generate a holds report csv file and return the file path"""
return self.generate_csv_report(_db, sql_params, self.holds_report_query())
self.log.debug(f"Zip file written to {zip_path}")
self.log.info(
f"Emailed inventory and holds reports for {library.name}({library.short_name})."
)

def create_temp_file(self) -> IO[str]:
    """Open a new named temporary file for writing UTF-8 text.

    The file is created through ``tempfile.NamedTemporaryFile`` with its
    default deletion behavior, so it is removed once the returned handle
    is closed. Callers are expected to use the handle as a context manager.

    :return: A writable text-mode file object whose ``name`` attribute is
        the path of the temporary file on disk.
    """
    temp_file = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8")
    return temp_file

def generate_csv_report(
self, _db: Session, sql_params: dict[str, Any], query: str
) -> str:
with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as temp:
writer = csv.writer(temp, delimiter=",")
rows = _db.execute(
text(query),
sql_params,
)
writer.writerow(rows.keys())
writer.writerows(rows)
return temp.name
self,
_db: Session,
csv_file: IO[str],
sql_params: dict[str, Any],
query: str,
) -> None:
writer = csv.writer(csv_file, delimiter=",")
rows = _db.execute(
text(query),
sql_params,
)
writer.writerow(rows.keys())
writer.writerows(rows)
csv_file.flush()
self.log.debug(f"report written to {csv_file.name}")

@staticmethod
def inventory_report_query() -> str:
Expand Down
123 changes: 71 additions & 52 deletions tests/manager/celery/tasks/test_generate_inventory_and_hold_reports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import csv
import io
import os
import zipfile
from datetime import timedelta
from typing import IO
from unittest.mock import create_autospec

from pytest import LogCaptureFixture
Expand Down Expand Up @@ -199,58 +202,74 @@ def test_job_run(
assert "Inventory and Holds Reports" in kwargs["subject"]
attachments: dict = kwargs["attachments"]

assert len(attachments) == 2
inventory_report_key = [x for x in attachments.keys() if "inventory" in x][0]
assert inventory_report_key
assert "test_library" in inventory_report_key
inventory_report_value = attachments[inventory_report_key]
assert inventory_report_value
inventory_report_csv = list(csv.DictReader(open(inventory_report_value)))

assert len(inventory_report_csv) == 1
for row in inventory_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert float(row["days_remaining_on_license"]) == float(days_remaining)
assert row["shared_active_loan_count"] == "0"
assert row["library_active_loan_count"] == "0"
assert row["remaining_loans"] == str(checkouts_left)
assert row["allowed_concurrent_users"] == str(terms_concurrency)
assert expiration.strftime("%Y-%m-%d %H:%M:%S.%f") in row["license_expiration"]

holds_report_key = [x for x in attachments.keys() if "holds" in x][0]
assert holds_report_key
assert "test_library" in holds_report_key
holds_report_value = attachments[holds_report_key]
assert holds_report_value
holds_report_csv = list(csv.DictReader(open(holds_report_value)))
assert len(holds_report_csv) == 1

for row in holds_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert int(row["shared_active_hold_count"]) == shared_patrons_in_hold_queue
assert int(row["library_active_hold_count"]) == 3

# clean up files
for f in attachments.values():
os.remove(f)
assert len(attachments) == 1
reports_zip = list(attachments.values())[0]
try:
with zipfile.ZipFile(reports_zip, mode="r") as archive:
entry_list = archive.namelist()
assert len(entry_list) == 2
with (
archive.open(entry_list[0]) as holds_report_zip_entry,
archive.open(entry_list[1]) as inventory_report_zip_entry,
):
assert inventory_report_zip_entry
assert "test_library" in inventory_report_zip_entry.name
inventory_report_csv = zip_csv_entry_to_dict(inventory_report_zip_entry)

assert len(inventory_report_csv) == 1
for row in inventory_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert float(row["days_remaining_on_license"]) == float(
days_remaining
)
assert row["shared_active_loan_count"] == "0"
assert row["library_active_loan_count"] == "0"
assert row["remaining_loans"] == str(checkouts_left)
assert row["allowed_concurrent_users"] == str(terms_concurrency)
assert (
expiration.strftime("%Y-%m-%d %H:%M:%S.%f")
in row["license_expiration"]
)

assert holds_report_zip_entry
assert "test_library" in holds_report_zip_entry.name
assert holds_report_zip_entry
holds_report_csv = zip_csv_entry_to_dict(holds_report_zip_entry)
assert len(holds_report_csv) == 1

for row in holds_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert (
int(row["shared_active_hold_count"])
== shared_patrons_in_hold_queue
)
assert int(row["library_active_hold_count"]) == 3
finally:
os.remove(reports_zip)


def zip_csv_entry_to_dict(zip_entry: IO[bytes]) -> list[dict[str, str]]:
    """Parse a CSV file read from an open zip archive entry.

    :param zip_entry: A readable binary stream (for example the object
        returned by ``ZipFile.open``) containing UTF-8 encoded CSV data
        whose first row is the header.
    :return: One dict per data row, keyed by the header column names.
        Despite the function name, the result is a list of dicts.
    """
    # newline="" passes line endings through untouched, as the csv module
    # requires; without it, universal-newline translation would rewrite
    # "\r\n" sequences embedded inside quoted fields.
    wrapper = io.TextIOWrapper(zip_entry, encoding="UTF-8", newline="")
    return list(csv.DictReader(wrapper))


def create_test_opds_collection(
Expand Down

0 comments on commit f3ad4c7

Please sign in to comment.