Skip to content

Commit

Permalink
Replace all query.get with session.get to stop deprecation warning in…
Browse files Browse the repository at this point in the history
… SQLAlchemy 2.0
  • Loading branch information
jace committed Jan 3, 2024
1 parent 46a79fe commit 9fbdfa3
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 39 deletions.
10 changes: 5 additions & 5 deletions funnel/cli/geodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def load_country_info(filename: str) -> None:
GeoCountryInfo.query.all() # Load everything into session cache
for item in countryinfo:
if item.geonameid:
ci = GeoCountryInfo.query.get(int(item.geonameid))
ci = db.session.get(GeoCountryInfo, int(item.geonameid))
if ci is None:
ci = GeoCountryInfo(geonameid=int(item.geonameid))
db.session.add(ci)
Expand Down Expand Up @@ -276,7 +276,7 @@ def load_geonames(filename: str) -> None:

for item in rich.progress.track(geonames):
if item.geonameid:
gn = GeoName.query.get(int(item.geonameid))
gn = db.session.get(GeoName, int(item.geonameid))
if gn is None:
gn = GeoName(geonameid=int(item.geonameid))
db.session.add(gn)
Expand Down Expand Up @@ -335,7 +335,7 @@ def load_alt_names(filename: str) -> None:

for item in rich.progress.track(altnames):
if item.geonameid:
rec = GeoAltName.query.get(int(item.id))
rec = db.session.get(GeoAltName, int(item.id))
if rec is None:
rec = GeoAltName(id=int(item.id))
db.session.add(rec)
Expand Down Expand Up @@ -368,7 +368,7 @@ def load_admin1_codes(filename: str) -> None:
GeoAdmin1Code.query.all() # Load all data into session cache for faster lookup
for item in rich.progress.track(admincodes):
if item.geonameid:
rec = GeoAdmin1Code.query.get(item.geonameid)
rec = db.session.get(GeoAdmin1Code, item.geonameid)
if rec is None:
rec = GeoAdmin1Code(geonameid=int(item.geonameid))
db.session.add(rec)
Expand Down Expand Up @@ -397,7 +397,7 @@ def load_admin2_codes(filename: str) -> None:
GeoAdmin2Code.query.all() # Load all data into session cache for faster lookup
for item in rich.progress.track(admincodes):
if item.geonameid:
rec = GeoAdmin2Code.query.get(item.geonameid)
rec = db.session.get(GeoAdmin2Code, item.geonameid)
if rec is None:
rec = GeoAdmin2Code(geonameid=int(item.geonameid))
db.session.add(rec)
Expand Down
2 changes: 1 addition & 1 deletion funnel/models/geoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def related_geonames(self) -> dict[str, GeoName]:
and self.country
and self.country.continent
):
continent = GeoName.query.get(continent_codes[self.country.continent])
continent = db.session.get(GeoName, continent_codes[self.country.continent])
if continent:
related['continent'] = continent

Expand Down
2 changes: 1 addition & 1 deletion funnel/models/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def recipients_iter(self) -> Iterator[MailerRecipient]:
.all()
]
for rid in ids:
recipient = MailerRecipient.query.get(rid)
recipient = db.session.get(MailerRecipient, rid)
if recipient:
yield recipient

Expand Down
10 changes: 6 additions & 4 deletions funnel/models/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ def dispatch(self) -> Generator[NotificationRecipient, None, None]:

# Since this query uses SQLAlchemy's session cache, we don't have to
# bother with a local cache for the first case.
existing_notification = NotificationRecipient.query.get(
(account.id, self.eventid)
existing_notification = db.session.get(
NotificationRecipient, (account.id, self.eventid)
)
if existing_notification is None:
recipient = NotificationRecipient(
Expand Down Expand Up @@ -1163,7 +1163,7 @@ def rolledup_fragments(self) -> Query | None:
@classmethod
def get_for(cls, user: Account, eventid_b58: str) -> NotificationRecipient | None:
"""Retrieve a :class:`UserNotification` using SQLAlchemy session cache."""
return cls.query.get((user.id, uuid_from_base58(eventid_b58)))
return db.session.get(cls, (user.id, uuid_from_base58(eventid_b58)))

@classmethod
def web_notifications_for(
Expand Down Expand Up @@ -1199,7 +1199,9 @@ def migrate_account(cls, old_account: Account, new_account: Account) -> None:
for notification_recipient in cls.query.filter_by(
recipient_id=old_account.id
).all():
existing = cls.query.get((new_account.id, notification_recipient.eventid))
existing = db.session.get(
cls, (new_account.id, notification_recipient.eventid)
)
# TODO: Instead of dropping old_user's dupe notifications, check which of
# the two has a higher priority role and keep that. This may not be possible
# if the two copies are for different notifications under the same eventid.
Expand Down
2 changes: 1 addition & 1 deletion funnel/models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def add(
account = project.account
if name is None:
name = project.name
redirect = cls.query.get((account.id, name))
redirect = db.session.get(cls, (account.id, name))
if redirect is None:
redirect = cls(account=account, name=name, project=project)
db.session.add(redirect)
Expand Down
2 changes: 1 addition & 1 deletion funnel/models/rsvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def get_for(
cls, project: Project, account: Account | None, create: bool = False
) -> Self | None:
if account is not None:
result = cls.query.get((project.id, account.id))
result = db.session.get(cls, (project.id, account.id))
if not result and create:
result = cls(project=project, participant=account)
db.session.add(result)
Expand Down
2 changes: 1 addition & 1 deletion funnel/models/sync_ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def roles_for(
if actor is not None:
if actor == self.participant:
roles.add('member')
cx = ContactExchange.query.get((actor.id, self.id))
cx = db.session.get(ContactExchange, (actor.id, self.id))
if cx is not None:
roles.add('scanner')
return roles
Expand Down
6 changes: 3 additions & 3 deletions funnel/views/api/geoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from coaster.views import requestargs

from ... import app
from ...models import GeoName
from ...models import GeoName, db
from ...typing import ReturnView


Expand All @@ -19,7 +19,7 @@ def geo_get_by_name(
) -> ReturnView:
"""Get a geoname record given a single URL stub name or geoname id."""
if name.isdigit():
geoname = GeoName.query.get(int(name))
geoname = db.session.get(GeoName, int(name))
else:
geoname = GeoName.get(name)
return (
Expand All @@ -43,7 +43,7 @@ def geo_get_by_names(
geonames = []
for n in name:
if n.isdigit():
geoname = GeoName.query.get(int(n))
geoname = db.session.get(GeoName, int(n))
else:
geoname = GeoName.get(n)
if geoname:
Expand Down
6 changes: 3 additions & 3 deletions funnel/views/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co:
@rqjob()
def import_tickets(ticket_client_id: int) -> None:
"""Import tickets from Boxoffice."""
ticket_client = TicketClient.query.get(ticket_client_id)
ticket_client = db.session.get(TicketClient, ticket_client_id)
if ticket_client is not None:
if ticket_client.name.lower() == 'explara':
ticket_list = ExplaraAPI(
Expand All @@ -79,7 +79,7 @@ def import_tickets(ticket_client_id: int) -> None:
@rqjob()
def tag_locations(project_id: int) -> None:
"""Tag a project with geoname locations. This is legacy code pending a rewrite."""
project = Project.query.get(project_id)
project = db.session.get(Project, project_id)
if project is None:
return
if not project.location:
Expand Down Expand Up @@ -116,7 +116,7 @@ def tag_locations(project_id: int) -> None:
project.parsed_location = {'tokens': tokens}

for locdata in geonames.values():
loc = ProjectLocation.query.get((project_id, locdata['geonameid']))
loc = db.session.get(ProjectLocation, (project_id, locdata['geonameid']))
if loc is None:
loc = ProjectLocation(project=project, geonameid=locdata['geonameid'])
db.session.add(loc)
Expand Down
2 changes: 1 addition & 1 deletion funnel/views/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_draft(self, obj: ModelUuidProtocol | None = None) -> Draft | None:
`obj` is needed in case of multi-model views.
"""
obj = obj if obj is not None else self.obj
return Draft.query.get((self.model.__tablename__, obj.uuid))
return db.session.get(Draft, (self.model.__tablename__, obj.uuid))

def delete_draft(self, obj=None):
"""Delete draft for `obj`, or `self.obj` if `obj` is `None`."""
Expand Down
8 changes: 5 additions & 3 deletions funnel/views/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def transport_worker_wrapper(
def inner(notification_recipient_ids: Sequence[tuple[int, UUID]]) -> None:
"""Convert a notification id into an object for worker to process."""
queue = [
NotificationRecipient.query.get(identity)
db.session.get(NotificationRecipient, identity)
for identity in notification_recipient_ids
]
for notification_recipient in queue:
Expand Down Expand Up @@ -627,7 +627,9 @@ def dispatch_transport_sms(
@rqjob()
def dispatch_notification_job(eventid: UUID, notification_ids: Sequence[UUID]) -> None:
"""Process :class:`Notification` into batches of :class:`UserNotification`."""
notifications = [Notification.query.get((eventid, nid)) for nid in notification_ids]
notifications = [
db.session.get(Notification, (eventid, nid)) for nid in notification_ids
]

# Dispatch, creating batches of DISPATCH_BATCH_SIZE each
for notification in notifications:
Expand Down Expand Up @@ -655,7 +657,7 @@ def dispatch_notification_recipients_job(
"""Process notifications for users and enqueue transport delivery."""
# TODO: Can this be a single query instead of a loop of queries?
queue = [
NotificationRecipient.query.get(identity)
db.session.get(NotificationRecipient, identity)
for identity in notification_recipient_ids
]
transport_batch: dict[str, list[tuple[int, UUID]]] = defaultdict(list)
Expand Down
27 changes: 19 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
from pprint import saferepr
from textwrap import indent
from types import MethodType, ModuleType, SimpleNamespace
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, get_type_hints
from typing import (
TYPE_CHECKING,
Any,
NamedTuple,
Protocol,
get_type_hints,
runtime_checkable,
)
from unittest.mock import patch

import flask
Expand All @@ -41,7 +48,7 @@
from rich.syntax import Syntax
from rich.text import Text
from sqlalchemy import event
from sqlalchemy.orm import Session as DatabaseSessionClass
from sqlalchemy.orm import Session as DatabaseSessionClass, scoped_session
from werkzeug import run_simple
from werkzeug.test import TestResponse

Expand Down Expand Up @@ -127,12 +134,15 @@ def pytest_runtest_call(item: pytest.Function) -> None:
# get_type_hints may fail on Python <3.10 because pytest-bdd appears to have
# `dict[str, str]` as a type somewhere, and builtin type subscripting isn't
# supported yet
warnings.warn(
f"Type annotations could not be retrieved for {item.obj!r}",
warnings.warn( # noqa: B028
f"Type annotations could not be retrieved for {item.obj.__qualname__}",
RuntimeWarning,
stacklevel=1,
)
return
except NameError as exc:
pytest.fail(
f"{item.obj.__qualname__} has an unknown annotation for {exc.name}. Is it imported under TYPE_CHECKING?"
)

for attr, type_ in annotations.items():
if attr in item.funcargs:
Expand Down Expand Up @@ -937,7 +947,7 @@ def drop_tables():
@pytest.fixture()
def db_session_truncate(
funnel, app, database, app_context
) -> Iterator[DatabaseSessionClass]:
) -> Iterator[DatabaseSessionClass | scoped_session]:
"""Empty the database after each use of the fixture."""
yield database.session
sa_orm.close_all_sessions()
Expand Down Expand Up @@ -996,7 +1006,7 @@ def get_bind(
@pytest.fixture()
def db_session_rollback(
funnel, app, database, app_context
) -> Iterator[DatabaseSessionClass]:
) -> Iterator[DatabaseSessionClass | scoped_session]:
"""Create a nested transaction for the test and rollback after."""
original_session = database.session

Expand Down Expand Up @@ -1036,7 +1046,7 @@ def db_session_rollback(


@pytest.fixture()
def db_session(request) -> DatabaseSessionClass:
def db_session(request) -> DatabaseSessionClass | scoped_session:
"""
Database session fixture.
Expand Down Expand Up @@ -1148,6 +1158,7 @@ def csrf_token(app, client) -> str:
return token


@runtime_checkable
class LoginFixtureProtocol(Protocol):
def as_(self, user: funnel_models.User) -> None:
...
Expand Down
43 changes: 36 additions & 7 deletions tests/integration/views/label_views_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
"""Test Label views."""

import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import scoped_session

from funnel import models

from ...conftest import LoginFixtureProtocol


@pytest.mark.dbcommit()
def test_manage_labels_view(
app, client, login, new_project, new_user, new_label, new_main_label
app: Flask,
client: FlaskClient,
login: LoginFixtureProtocol,
new_project: models.Project,
new_user: models.User,
new_label: models.Label,
new_main_label: models.Label,
) -> None:
login.as_(new_user)
resp = client.get(new_project.url_for('labels'))
Expand All @@ -17,26 +28,44 @@ def test_manage_labels_view(


@pytest.mark.dbcommit()
def test_edit_option_label_view(app, client, login, new_user, new_main_label) -> None:
def test_edit_option_label_view(
app: Flask,
client: FlaskClient,
login: LoginFixtureProtocol,
new_user: models.User,
new_main_label: models.Label,
) -> None:
login.as_(new_user)
opt_label = new_main_label.options[0]
resp = client.post(opt_label.url_for('edit'), follow_redirects=True)
assert "Manage labels" in resp.data.decode('utf-8')
assert "Only main labels can be edited" in resp.data.decode('utf-8')


@pytest.mark.xfail(reason="Broken by Flask-SQLAlchemy 3.0, unclear why") # FIXME
def test_main_label_delete(client, login, new_user, new_label) -> None:
@pytest.mark.xfail(reason="Broken after Flask-SQLAlchemy 3.0, unclear why") # FIXME
def test_main_label_delete(
db_session: scoped_session,
client: FlaskClient,
login: LoginFixtureProtocol,
new_user: models.User,
new_label: models.Label,
) -> None:
login.as_(new_user)
resp = client.post(new_label.url_for('delete'), follow_redirects=True)
assert "Manage labels" in resp.data.decode('utf-8')
assert "The label has been deleted" in resp.data.decode('utf-8')
label = models.Label.query.get(new_label.id)
label = db_session.get(models.Label, new_label.id)
assert label is None


@pytest.mark.xfail(reason="Broken after Flask-SQLAlchemy 3.0, unclear why") # FIXME
def test_optioned_label_delete(client, login, new_user, new_main_label) -> None:
def test_optioned_label_delete(
db_session: scoped_session,
client: FlaskClient,
login: LoginFixtureProtocol,
new_user: models.User,
new_main_label: models.Label,
) -> None:
login.as_(new_user)
label_a1 = new_main_label.options[0]
label_a2 = new_main_label.options[1]
Expand All @@ -45,7 +74,7 @@ def test_optioned_label_delete(client, login, new_user, new_main_label) -> None:
resp = client.post(new_main_label.url_for('delete'), follow_redirects=True)
assert "Manage labels" in resp.data.decode('utf-8')
assert "The label has been deleted" in resp.data.decode('utf-8')
mlabel = models.Label.query.get(new_main_label.id)
mlabel = db_session.get(models.Label, new_main_label.id)
assert mlabel is None

# so the option labels should have been deleted as well
Expand Down

0 comments on commit 9fbdfa3

Please sign in to comment.