Skip to content

Commit

Permalink
Finish type hinting sqlalchemy utils file (#2087)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathangreen authored Sep 27, 2024
1 parent 344383b commit 0f47da2
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 20 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ module = [
"palace.manager.sqlalchemy.model.integration",
"palace.manager.sqlalchemy.model.library",
"palace.manager.sqlalchemy.model.patron",
"palace.manager.sqlalchemy.util",
"palace.manager.util.authentication_for_opds",
"palace.manager.util.base64",
"palace.manager.util.cache",
Expand Down
6 changes: 4 additions & 2 deletions src/palace/manager/sqlalchemy/model/time_tracking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import (
Boolean,
Expand Down Expand Up @@ -185,7 +185,9 @@ def add(
"library_name": None if library else library_name,
"loan_identifier": loan_identifier,
}
lookup_keys = {k: v for k, v in _potential_lookup_keys.items() if v is not None}
lookup_keys: dict[str, Any] = {
k: v for k, v in _potential_lookup_keys.items() if v is not None
}
additional_columns = {
k: v
for k, v in {
Expand Down
34 changes: 20 additions & 14 deletions src/palace/manager/sqlalchemy/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import logging
from collections.abc import Generator
from typing import Literal, TypeVar
from collections.abc import Generator, Mapping
from typing import Any, Literal, TypeVar

from contextlib2 import contextmanager
from psycopg2._range import NumericRange
Expand Down Expand Up @@ -44,7 +44,7 @@ def pg_advisory_lock(
connection.execute(text(f"SELECT pg_advisory_unlock({lock_id});"))


def flush(db):
def flush(db: Session) -> None:
"""Flush the database connection unless it's known to already be flushing."""
is_flushing = False
if hasattr(db, "_flushing"):
Expand All @@ -63,7 +63,11 @@ def flush(db):


def create(
db: Session, model: type[T], create_method="", create_method_kwargs=None, **kwargs
db: Session,
model: type[T],
create_method: str = "",
create_method_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> tuple[T, Literal[True]]:
kwargs.update(create_method_kwargs or {})
created = getattr(model, create_method, model)(**kwargs)
Expand All @@ -73,25 +77,24 @@ def create(


def get_one(
db: Session, model: type[T], on_multiple="error", constraint=None, **kwargs
db: Session,
model: type[T],
on_multiple: Literal["interchangeable"] | Literal["error"] = "error",
constraint: Any = None,
**kwargs: Any,
) -> T | None:
"""Gets an object from the database based on its attributes.
:param constraint: A single clause that can be passed into
`sqlalchemy.Query.filter` to limit the object that is returned.
:return: object or None
"""
constraint = constraint
if "constraint" in kwargs:
constraint = kwargs["constraint"]
del kwargs["constraint"]

q = db.query(model).filter_by(**kwargs)
if constraint is not None:
q = q.filter(constraint)

try:
return q.one()
return q.one() # type: ignore[no-any-return]
except MultipleResultsFound:
if on_multiple == "error":
raise
Expand All @@ -102,14 +105,17 @@ def get_one(
# This may be a sign of a problem somewhere else. A
# database-level constraint might be useful.
q = q.limit(1)
return q.one()
return q.one() # type: ignore[no-any-return]
except NoResultFound:
return None
return None


def get_one_or_create(
db: Session, model: type[T], create_method="", create_method_kwargs=None, **kwargs
db: Session,
model: type[T],
create_method: str = "",
create_method_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> tuple[T, bool]:
one = get_one(db, model, **kwargs)
if one:
Expand Down
9 changes: 5 additions & 4 deletions tests/manager/sqlalchemy/model/test_licensing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import json
from collections.abc import Callable
from typing import Any
from unittest.mock import MagicMock, PropertyMock

import pytest
Expand Down Expand Up @@ -226,15 +227,15 @@ def test_uniqueness_constraint(

# You can't create two DeliveryMechanisms with the same values
# for content_type and drm_scheme.
with_drm_args = dict(content_type="type1", drm_scheme="scheme1")
without_drm_args = dict(content_type="type1", drm_scheme=None)
with_drm = create(session, dm, **with_drm_args)
with_drm_args: dict[str, Any] = dict(content_type="type1", drm_scheme="scheme1")
create(session, dm, **with_drm_args)
pytest.raises(IntegrityError, create, session, dm, **with_drm_args)
session.rollback()

# You can't create two DeliveryMechanisms with the same value
# for content_type and a null value for drm_scheme.
without_drm = create(session, dm, **without_drm_args)
without_drm_args: dict[str, Any] = dict(content_type="type1", drm_scheme=None)
create(session, dm, **without_drm_args)
pytest.raises(IntegrityError, create, session, dm, **without_drm_args)
session.rollback()

Expand Down

0 comments on commit 0f47da2

Please sign in to comment.