From 0f47da260c507430565a00cd5a654d52fe1405ce Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Fri, 27 Sep 2024 10:08:35 -0300 Subject: [PATCH] Finish type hinting sqlalchemy utils file (#2087) --- pyproject.toml | 1 + .../manager/sqlalchemy/model/time_tracking.py | 6 ++-- src/palace/manager/sqlalchemy/util.py | 34 +++++++++++-------- .../sqlalchemy/model/test_licensing.py | 9 ++--- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 291503bff..831fca98f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ module = [ "palace.manager.sqlalchemy.model.integration", "palace.manager.sqlalchemy.model.library", "palace.manager.sqlalchemy.model.patron", + "palace.manager.sqlalchemy.util", "palace.manager.util.authentication_for_opds", "palace.manager.util.base64", "palace.manager.util.cache", diff --git a/src/palace/manager/sqlalchemy/model/time_tracking.py b/src/palace/manager/sqlalchemy/model/time_tracking.py index c636d0b23..56ea158ab 100644 --- a/src/palace/manager/sqlalchemy/model/time_tracking.py +++ b/src/palace/manager/sqlalchemy/model/time_tracking.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import ( Boolean, @@ -185,7 +185,9 @@ def add( "library_name": None if library else library_name, "loan_identifier": loan_identifier, } - lookup_keys = {k: v for k, v in _potential_lookup_keys.items() if v is not None} + lookup_keys: dict[str, Any] = { + k: v for k, v in _potential_lookup_keys.items() if v is not None + } additional_columns = { k: v for k, v in { diff --git a/src/palace/manager/sqlalchemy/util.py b/src/palace/manager/sqlalchemy/util.py index a09ce0b0d..4862680e3 100644 --- a/src/palace/manager/sqlalchemy/util.py +++ b/src/palace/manager/sqlalchemy/util.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from collections.abc import Generator -from typing import Literal, TypeVar +from collections.abc import Generator, Mapping +from typing import Any, Literal, TypeVar from contextlib2 import contextmanager from psycopg2._range import NumericRange @@ -44,7 +44,7 @@ def pg_advisory_lock( connection.execute(text(f"SELECT pg_advisory_unlock({lock_id});")) -def flush(db): +def flush(db: Session) -> None: """Flush the database connection unless it's known to already be flushing.""" is_flushing = False if hasattr(db, "_flushing"): @@ -63,7 +63,11 @@ def flush(db): def create( - db: Session, model: type[T], create_method="", create_method_kwargs=None, **kwargs + db: Session, + model: type[T], + create_method: str = "", + create_method_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, ) -> tuple[T, Literal[True]]: kwargs.update(create_method_kwargs or {}) created = getattr(model, create_method, model)(**kwargs) @@ -73,7 +77,11 @@ def create( def get_one( - db: Session, model: type[T], on_multiple="error", constraint=None, **kwargs + db: Session, + model: type[T], + on_multiple: Literal["interchangeable"] | Literal["error"] = "error", + constraint: Any = None, + **kwargs: Any, ) -> T | None: """Gets an object from the database based on its attributes. @@ -81,17 +89,12 @@ def get_one( `sqlalchemy.Query.filter` to limit the object that is returned. :return: object or None """ - constraint = constraint - if "constraint" in kwargs: - constraint = kwargs["constraint"] - del kwargs["constraint"] - q = db.query(model).filter_by(**kwargs) if constraint is not None: q = q.filter(constraint) try: - return q.one() + return q.one() # type: ignore[no-any-return] except MultipleResultsFound: if on_multiple == "error": raise @@ -102,14 +105,17 @@ def get_one( # This may be a sign of a problem somewhere else. A # database-level constraint might be useful. q = q.limit(1) - return q.one() + return q.one() # type: ignore[no-any-return] except NoResultFound: return None - return None def get_one_or_create( - db: Session, model: type[T], create_method="", create_method_kwargs=None, **kwargs + db: Session, + model: type[T], + create_method: str = "", + create_method_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, ) -> tuple[T, bool]: one = get_one(db, model, **kwargs) if one: diff --git a/tests/manager/sqlalchemy/model/test_licensing.py b/tests/manager/sqlalchemy/model/test_licensing.py index 973911acf..0670451aa 100644 --- a/tests/manager/sqlalchemy/model/test_licensing.py +++ b/tests/manager/sqlalchemy/model/test_licensing.py @@ -1,6 +1,7 @@ import datetime import json from collections.abc import Callable +from typing import Any from unittest.mock import MagicMock, PropertyMock import pytest @@ -226,15 +227,15 @@ def test_uniqueness_constraint( # You can't create two DeliveryMechanisms with the same values # for content_type and drm_scheme. - with_drm_args = dict(content_type="type1", drm_scheme="scheme1") - without_drm_args = dict(content_type="type1", drm_scheme=None) - with_drm = create(session, dm, **with_drm_args) + with_drm_args: dict[str, Any] = dict(content_type="type1", drm_scheme="scheme1") + create(session, dm, **with_drm_args) pytest.raises(IntegrityError, create, session, dm, **with_drm_args) session.rollback() # You can't create two DeliveryMechanisms with the same value # for content_type and a null value for drm_scheme. - without_drm = create(session, dm, **without_drm_args) + without_drm_args: dict[str, Any] = dict(content_type="type1", drm_scheme=None) + create(session, dm, **without_drm_args) pytest.raises(IntegrityError, create, session, dm, **without_drm_args) session.rollback()