Skip to content

Commit

Permalink
Rework the PR to ensure all file resources are managed with context m…
Browse files Browse the repository at this point in the history
…anagers where possible.
  • Loading branch information
dbernstein committed May 3, 2024
1 parent a4186d4 commit 8ae7482
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 146 deletions.
171 changes: 85 additions & 86 deletions src/palace/manager/celery/tasks/generate_inventory_and_hold_reports.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import csv
import os
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import Any

from celery import shared_task
Expand Down Expand Up @@ -59,108 +59,107 @@ def run(self) -> None:
f"Starting inventory and holds report job for {library.name}({library.short_name})."
)

try:
current_time = datetime.now()
date_str = current_time.strftime("%Y-%m-%d_%H:%M:%s")
attachments: dict[str, Path] = {}
current_time = datetime.now()
date_str = current_time.strftime("%Y-%m-%d_%H:%M:%s")

file_name_modifier = f"{library.short_name}-{date_str}"
file_name_modifier = f"{library.short_name}-{date_str}"

# resolve integrations
integrations = session.scalars(
select(IntegrationConfiguration)
.join(IntegrationLibraryConfiguration)
.where(
IntegrationLibraryConfiguration.library_id == self.library_id,
IntegrationConfiguration.goal == Goals.LICENSE_GOAL,
not_(
IntegrationConfiguration.settings_dict.contains(
{"include_in_inventory_report": False}
)
),
)
).all()
registry = LicenseProvidersRegistry()
integration_ids: list[int] = []
for integration in integrations:
settings = registry[integration.protocol].settings_load(integration)
if not isinstance(settings, OPDSImporterSettings):
continue
integration_ids.append(integration.id)
# resolve integrations
integrations = session.scalars(
select(IntegrationConfiguration)
.join(IntegrationLibraryConfiguration)
.where(
IntegrationLibraryConfiguration.library_id == self.library_id,
IntegrationConfiguration.goal == Goals.LICENSE_GOAL,
not_(
IntegrationConfiguration.settings_dict.contains(
{"include_in_inventory_report": False}
)
),
)
).all()
registry = LicenseProvidersRegistry()
integration_ids: list[int] = []
for integration in integrations:
settings = registry[integration.protocol].settings_load(integration)
if not isinstance(settings, OPDSImporterSettings):
continue
integration_ids.append(integration.id)

# generate inventory report csv file
sql_params: dict[str, Any] = {
"library_id": library.id,
"integration_ids": tuple(integration_ids),
}
# generate inventory report csv file
sql_params: dict[str, Any] = {
"library_id": library.id,
"integration_ids": tuple(integration_ids),
}

inventory_report_file_path = self.generate_inventory_report(
session, sql_params=sql_params
)
with tempfile.NamedTemporaryFile(
delete=self.delete_attachments
) as report_zip:
zip_path = Path(report_zip.name)

# generate holds report csv file
holds_report_file_path = self.generate_holds_report(
session, sql_params=sql_params
)
with (
self.create_temp_file() as inventory_report_file,
self.create_temp_file() as holds_report_file,
):
self.generate_csv_report(
session,
csv_file=inventory_report_file,
sql_params=sql_params,
query=self.inventory_report_query(),
)

self.generate_csv_report(
session,
csv_file=holds_report_file,
sql_params=sql_params,
query=self.holds_report_query(),
)

with tempfile.NamedTemporaryFile(delete=False) as tmp:
with zipfile.ZipFile(tmp, "w", zipfile.ZIP_DEFLATED) as archive:
with zipfile.ZipFile(
zip_path, "w", zipfile.ZIP_DEFLATED
) as archive:
archive.write(
filename=holds_report_file_path,
filename=holds_report_file.name,
arcname=f"palace-holds-report-for-library-{file_name_modifier}.csv",
)
archive.write(
filename=inventory_report_file_path,
filename=inventory_report_file.name,
arcname=f"palace-inventory-report-for-library-{file_name_modifier}.csv",
)

self.log.debug(f"Zip file written to {tmp.name}")
# clean up report files now that they have been written to the zipfile
for f in [inventory_report_file_path, holds_report_file_path]:
os.remove(f)

attachments[
f"palace-inventory-and-holds-reports-for-{file_name_modifier}.zip"
] = Path(tmp.name)
self.send_email(
subject=f"Inventory and Holds Reports {current_time}",
receivers=[self.email_address],
text="",
attachments=attachments,
)

self.log.info(
f"Emailed inventory and holds reports for {library.name}({library.short_name})."
)
finally:
if self.delete_attachments:
for file_path in attachments.values():
os.remove(file_path)
self.send_email(
subject=f"Inventory and Holds Reports {current_time}",
receivers=[self.email_address],
text="",
attachments={
f"palace-inventory-and-holds-reports-for-{file_name_modifier}.zip": zip_path
},
)

def generate_inventory_report(
self, _db: Session, sql_params: dict[str, Any]
) -> str:
"""Generate an inventory csv file and return the file path"""
return self.generate_csv_report(_db, sql_params, self.inventory_report_query())
self.log.debug(f"Zip file written to {zip_path}")
self.log.info(
f"Emailed inventory and holds reports for {library.name}({library.short_name})."
)

def generate_holds_report(self, _db: Session, sql_params: dict[str, Any]) -> str:
"""Generate a holds report csv file and return the file path"""
return self.generate_csv_report(_db, sql_params, self.holds_report_query())
def create_temp_file(self) -> _TemporaryFileWrapper[str]:
return tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8")

def generate_csv_report(
self, _db: Session, sql_params: dict[str, Any], query: str
) -> str:
with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as temp:
writer = csv.writer(temp, delimiter=",")
rows = _db.execute(
text(query),
sql_params,
)
writer.writerow(rows.keys())
writer.writerows(rows)

self.log.debug(f"temp file written to {temp.name}")
return temp.name
self,
_db: Session,
csv_file: _TemporaryFileWrapper[str],
sql_params: dict[str, Any],
query: str,
) -> None:
writer = csv.writer(csv_file, delimiter=",")
rows = _db.execute(
text(query),
sql_params,
)
writer.writerow(rows.keys())
writer.writerows(rows)
csv_file.flush()
self.log.debug(f"report written to {csv_file.name}")

@staticmethod
def inventory_report_query() -> str:
Expand Down
121 changes: 61 additions & 60 deletions tests/manager/celery/tasks/test_generate_inventory_and_hold_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,69 +204,70 @@ def test_job_run(

assert len(attachments) == 1
reports_zip = list(attachments.values())[0]
with zipfile.ZipFile(reports_zip, mode="r") as archive:
entry_list = archive.namelist()
assert len(entry_list) == 2
for entry in entry_list:
open_file = archive.open(entry)
if "inventory" in entry:
inventory_report_zip_entry = open_file
elif "holds" in entry:
holds_report_zip_entry = open_file

assert inventory_report_zip_entry
assert "test_library" in inventory_report_zip_entry.name
inventory_report_csv = zip_csv_entry_to_dict(inventory_report_zip_entry)

assert len(inventory_report_csv) == 1
for row in inventory_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert float(row["days_remaining_on_license"]) == float(days_remaining)
assert row["shared_active_loan_count"] == "0"
assert row["library_active_loan_count"] == "0"
assert row["remaining_loans"] == str(checkouts_left)
assert row["allowed_concurrent_users"] == str(terms_concurrency)
assert (
expiration.strftime("%Y-%m-%d %H:%M:%S.%f") in row["license_expiration"]
)

assert holds_report_zip_entry
assert "test_library" in holds_report_zip_entry.name
assert holds_report_zip_entry
holds_report_csv = zip_csv_entry_to_dict(holds_report_zip_entry)
assert len(holds_report_csv) == 1

for row in holds_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert int(row["shared_active_hold_count"]) == shared_patrons_in_hold_queue
assert int(row["library_active_hold_count"]) == 3

# clean up files
for f in attachments.values():
os.remove(f)
try:
with zipfile.ZipFile(reports_zip, mode="r") as archive:
entry_list = archive.namelist()
assert len(entry_list) == 2
with (
archive.open(entry_list[0]) as holds_report_zip_entry,
archive.open(entry_list[1]) as inventory_report_zip_entry,
):
assert inventory_report_zip_entry
assert "test_library" in inventory_report_zip_entry.name
inventory_report_csv = zip_csv_entry_to_dict(inventory_report_zip_entry)

assert len(inventory_report_csv) == 1
for row in inventory_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert float(row["days_remaining_on_license"]) == float(
days_remaining
)
assert row["shared_active_loan_count"] == "0"
assert row["library_active_loan_count"] == "0"
assert row["remaining_loans"] == str(checkouts_left)
assert row["allowed_concurrent_users"] == str(terms_concurrency)
assert (
expiration.strftime("%Y-%m-%d %H:%M:%S.%f")
in row["license_expiration"]
)

assert holds_report_zip_entry
assert "test_library" in holds_report_zip_entry.name
assert holds_report_zip_entry
holds_report_csv = zip_csv_entry_to_dict(holds_report_zip_entry)
assert len(holds_report_csv) == 1

for row in holds_report_csv:
assert row["title"] == title
assert row["author"] == author
assert row["identifier"]
assert row["language"] == language
assert row["publisher"] == publisher
assert row["audience"] == "young adult"
assert row["genres"] == "genre_a,genre_z"
assert row["format"] == edition.BOOK_MEDIUM
assert row["data_source"] == data_source
assert row["collection_name"] == collection_name
assert (
int(row["shared_active_hold_count"])
== shared_patrons_in_hold_queue
)
assert int(row["library_active_hold_count"]) == 3
finally:
os.remove(reports_zip)


def zip_csv_entry_to_dict(zip_entry: IO[bytes]):
buffer = io.BytesIO(zip_entry.read())
wrapper = io.TextIOWrapper(buffer, encoding="UTF-8")
wrapper = io.TextIOWrapper(zip_entry, encoding="UTF-8")
csv_dict = list(csv.DictReader(wrapper))
return csv_dict

Expand Down

0 comments on commit 8ae7482

Please sign in to comment.