From 9fbdfa33e0094a7f87bc279f21411af62da8d32e Mon Sep 17 00:00:00 2001 From: Kiran Jonnalagadda Date: Wed, 3 Jan 2024 23:35:51 +0530 Subject: [PATCH] Replace all query.get with session.get to stop deprecation warning in SQLAlchemy 2.0 --- funnel/cli/geodata.py | 10 ++--- funnel/models/geoname.py | 2 +- funnel/models/mailer.py | 2 +- funnel/models/notification.py | 10 +++-- funnel/models/project.py | 2 +- funnel/models/rsvp.py | 2 +- funnel/models/sync_ticket.py | 2 +- funnel/views/api/geoname.py | 6 +-- funnel/views/jobs.py | 6 +-- funnel/views/mixins.py | 2 +- funnel/views/notification.py | 8 ++-- tests/conftest.py | 27 +++++++++---- tests/integration/views/label_views_test.py | 43 +++++++++++++++++---- 13 files changed, 83 insertions(+), 39 deletions(-) diff --git a/funnel/cli/geodata.py b/funnel/cli/geodata.py index 978f03850..839d29b1b 100644 --- a/funnel/cli/geodata.py +++ b/funnel/cli/geodata.py @@ -167,7 +167,7 @@ def load_country_info(filename: str) -> None: GeoCountryInfo.query.all() # Load everything into session cache for item in countryinfo: if item.geonameid: - ci = GeoCountryInfo.query.get(int(item.geonameid)) + ci = db.session.get(GeoCountryInfo, int(item.geonameid)) if ci is None: ci = GeoCountryInfo(geonameid=int(item.geonameid)) db.session.add(ci) @@ -276,7 +276,7 @@ def load_geonames(filename: str) -> None: for item in rich.progress.track(geonames): if item.geonameid: - gn = GeoName.query.get(int(item.geonameid)) + gn = db.session.get(GeoName, int(item.geonameid)) if gn is None: gn = GeoName(geonameid=int(item.geonameid)) db.session.add(gn) @@ -335,7 +335,7 @@ def load_alt_names(filename: str) -> None: for item in rich.progress.track(altnames): if item.geonameid: - rec = GeoAltName.query.get(int(item.id)) + rec = db.session.get(GeoAltName, int(item.id)) if rec is None: rec = GeoAltName(id=int(item.id)) db.session.add(rec) @@ -368,7 +368,7 @@ def load_admin1_codes(filename: str) -> None: GeoAdmin1Code.query.all() # Load all data into session cache for faster lookup for item in rich.progress.track(admincodes): if item.geonameid: - rec = GeoAdmin1Code.query.get(item.geonameid) + rec = db.session.get(GeoAdmin1Code, item.geonameid) if rec is None: rec = GeoAdmin1Code(geonameid=int(item.geonameid)) db.session.add(rec) @@ -397,7 +397,7 @@ def load_admin2_codes(filename: str) -> None: GeoAdmin2Code.query.all() # Load all data into session cache for faster lookup for item in rich.progress.track(admincodes): if item.geonameid: - rec = GeoAdmin2Code.query.get(item.geonameid) + rec = db.session.get(GeoAdmin2Code, item.geonameid) if rec is None: rec = GeoAdmin2Code(geonameid=int(item.geonameid)) db.session.add(rec) diff --git a/funnel/models/geoname.py b/funnel/models/geoname.py index 3c2b72a3c..ace2cd2fb 100644 --- a/funnel/models/geoname.py +++ b/funnel/models/geoname.py @@ -391,7 +391,7 @@ def related_geonames(self) -> dict[str, GeoName]: and self.country and self.country.continent ): - continent = GeoName.query.get(continent_codes[self.country.continent]) + continent = db.session.get(GeoName, continent_codes[self.country.continent]) if continent: related['continent'] = continent diff --git a/funnel/models/mailer.py b/funnel/models/mailer.py index dcefe9cf5..c60e7f6d7 100644 --- a/funnel/models/mailer.py +++ b/funnel/models/mailer.py @@ -156,7 +156,7 @@ def recipients_iter(self) -> Iterator[MailerRecipient]: .all() ] for rid in ids: - recipient = MailerRecipient.query.get(rid) + recipient = db.session.get(MailerRecipient, rid) if recipient: yield recipient diff --git a/funnel/models/notification.py b/funnel/models/notification.py index 7715e96ff..498bfd916 100644 --- a/funnel/models/notification.py +++ b/funnel/models/notification.py @@ -693,8 +693,8 @@ def dispatch(self) -> Generator[NotificationRecipient, None, None]: # Since this query uses SQLAlchemy's session cache, we don't have to # bother with a local cache for the first case. - existing_notification = NotificationRecipient.query.get( - (account.id, self.eventid) + existing_notification = db.session.get( + NotificationRecipient, (account.id, self.eventid) ) if existing_notification is None: recipient = NotificationRecipient( @@ -1163,7 +1163,7 @@ def rolledup_fragments(self) -> Query | None: @classmethod def get_for(cls, user: Account, eventid_b58: str) -> NotificationRecipient | None: """Retrieve a :class:`UserNotification` using SQLAlchemy session cache.""" - return cls.query.get((user.id, uuid_from_base58(eventid_b58))) + return db.session.get(cls, (user.id, uuid_from_base58(eventid_b58))) @classmethod def web_notifications_for( @@ -1199,7 +1199,9 @@ def migrate_account(cls, old_account: Account, new_account: Account) -> None: for notification_recipient in cls.query.filter_by( recipient_id=old_account.id ).all(): - existing = cls.query.get((new_account.id, notification_recipient.eventid)) + existing = db.session.get( + cls, (new_account.id, notification_recipient.eventid) + ) # TODO: Instead of dropping old_user's dupe notifications, check which of # the two has a higher priority role and keep that. This may not be possible # if the two copies are for different notifications under the same eventid. diff --git a/funnel/models/project.py b/funnel/models/project.py index 1549420ae..6889dbf98 100644 --- a/funnel/models/project.py +++ b/funnel/models/project.py @@ -1509,7 +1509,7 @@ def add( account = project.account if name is None: name = project.name - redirect = cls.query.get((account.id, name)) + redirect = db.session.get(cls, (account.id, name)) if redirect is None: redirect = cls(account=account, name=name, project=project) db.session.add(redirect) diff --git a/funnel/models/rsvp.py b/funnel/models/rsvp.py index 250920a95..9480391b2 100644 --- a/funnel/models/rsvp.py +++ b/funnel/models/rsvp.py @@ -193,7 +193,7 @@ def get_for( cls, project: Project, account: Account | None, create: bool = False ) -> Self | None: if account is not None: - result = cls.query.get((project.id, account.id)) + result = db.session.get(cls, (project.id, account.id)) if not result and create: result = cls(project=project, participant=account) db.session.add(result) diff --git a/funnel/models/sync_ticket.py b/funnel/models/sync_ticket.py index 9739db7a5..d2f31eb99 100644 --- a/funnel/models/sync_ticket.py +++ b/funnel/models/sync_ticket.py @@ -302,7 +302,7 @@ def roles_for( if actor is not None: if actor == self.participant: roles.add('member') - cx = ContactExchange.query.get((actor.id, self.id)) + cx = db.session.get(ContactExchange, (actor.id, self.id)) if cx is not None: roles.add('scanner') return roles diff --git a/funnel/views/api/geoname.py b/funnel/views/api/geoname.py index dc993a497..9d2db4715 100644 --- a/funnel/views/api/geoname.py +++ b/funnel/views/api/geoname.py @@ -8,7 +8,7 @@ from coaster.views import requestargs from ... import app -from ...models import GeoName +from ...models import GeoName, db from ...typing import ReturnView @@ -19,7 +19,7 @@ def geo_get_by_name( ) -> ReturnView: """Get a geoname record given a single URL stub name or geoname id.""" if name.isdigit(): - geoname = GeoName.query.get(int(name)) + geoname = db.session.get(GeoName, int(name)) else: geoname = GeoName.get(name) return ( @@ -43,7 +43,7 @@ def geo_get_by_names( geonames = [] for n in name: if n.isdigit(): - geoname = GeoName.query.get(int(n)) + geoname = db.session.get(GeoName, int(n)) else: geoname = GeoName.get(n) if geoname: diff --git a/funnel/views/jobs.py b/funnel/views/jobs.py index 4a6f67d0f..78bae910f 100644 --- a/funnel/views/jobs.py +++ b/funnel/views/jobs.py @@ -61,7 +61,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: @rqjob() def import_tickets(ticket_client_id: int) -> None: """Import tickets from Boxoffice.""" - ticket_client = TicketClient.query.get(ticket_client_id) + ticket_client = db.session.get(TicketClient, ticket_client_id) if ticket_client is not None: if ticket_client.name.lower() == 'explara': ticket_list = ExplaraAPI( @@ -79,7 +79,7 @@ def import_tickets(ticket_client_id: int) -> None: @rqjob() def tag_locations(project_id: int) -> None: """Tag a project with geoname locations. This is legacy code pending a rewrite.""" - project = Project.query.get(project_id) + project = db.session.get(Project, project_id) if project is None: return if not project.location: @@ -116,7 +116,7 @@ def tag_locations(project_id: int) -> None: project.parsed_location = {'tokens': tokens} for locdata in geonames.values(): - loc = ProjectLocation.query.get((project_id, locdata['geonameid'])) + loc = db.session.get(ProjectLocation, (project_id, locdata['geonameid'])) if loc is None: loc = ProjectLocation(project=project, geonameid=locdata['geonameid']) db.session.add(loc) diff --git a/funnel/views/mixins.py b/funnel/views/mixins.py index ca836f79a..1cbaadf8c 100644 --- a/funnel/views/mixins.py +++ b/funnel/views/mixins.py @@ -140,7 +140,7 @@ def get_draft(self, obj: ModelUuidProtocol | None = None) -> Draft | None: `obj` is needed in case of multi-model views. """ obj = obj if obj is not None else self.obj - return Draft.query.get((self.model.__tablename__, obj.uuid)) + return db.session.get(Draft, (self.model.__tablename__, obj.uuid)) def delete_draft(self, obj=None): """Delete draft for `obj`, or `self.obj` if `obj` is `None`.""" diff --git a/funnel/views/notification.py b/funnel/views/notification.py index 72a529365..57e90fa29 100644 --- a/funnel/views/notification.py +++ b/funnel/views/notification.py @@ -509,7 +509,7 @@ def transport_worker_wrapper( def inner(notification_recipient_ids: Sequence[tuple[int, UUID]]) -> None: """Convert a notification id into an object for worker to process.""" queue = [ - NotificationRecipient.query.get(identity) + db.session.get(NotificationRecipient, identity) for identity in notification_recipient_ids ] for notification_recipient in queue: @@ -627,7 +627,9 @@ def dispatch_transport_sms( @rqjob() def dispatch_notification_job(eventid: UUID, notification_ids: Sequence[UUID]) -> None: """Process :class:`Notification` into batches of :class:`UserNotification`.""" - notifications = [Notification.query.get((eventid, nid)) for nid in notification_ids] + notifications = [ + db.session.get(Notification, (eventid, nid)) for nid in notification_ids + ] # Dispatch, creating batches of DISPATCH_BATCH_SIZE each for notification in notifications: @@ -655,7 +657,7 @@ def dispatch_notification_recipients_job( """Process notifications for users and enqueue transport delivery.""" # TODO: Can this be a single query instead of a loop of queries? queue = [ - NotificationRecipient.query.get(identity) + db.session.get(NotificationRecipient, identity) for identity in notification_recipient_ids ] transport_batch: dict[str, list[tuple[int, UUID]]] = defaultdict(list) diff --git a/tests/conftest.py b/tests/conftest.py index a0c6180da..33e0fe093 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,14 @@ from pprint import saferepr from textwrap import indent from types import MethodType, ModuleType, SimpleNamespace -from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, get_type_hints +from typing import ( + TYPE_CHECKING, + Any, + NamedTuple, + Protocol, + get_type_hints, + runtime_checkable, +) from unittest.mock import patch import flask @@ -41,7 +48,7 @@ from rich.syntax import Syntax from rich.text import Text from sqlalchemy import event -from sqlalchemy.orm import Session as DatabaseSessionClass +from sqlalchemy.orm import Session as DatabaseSessionClass, scoped_session from werkzeug import run_simple from werkzeug.test import TestResponse @@ -127,12 +134,15 @@ def pytest_runtest_call(item: pytest.Function) -> None: # get_type_hints may fail on Python <3.10 because pytest-bdd appears to have # `dict[str, str]` as a type somewhere, and builtin type subscripting isn't # supported yet - warnings.warn( - f"Type annotations could not be retrieved for {item.obj!r}", + warnings.warn( # noqa: B028 + f"Type annotations could not be retrieved for {item.obj.__qualname__}", RuntimeWarning, - stacklevel=1, ) return + except NameError as exc: + pytest.fail( + f"{item.obj.__qualname__} has an unknown annotation for {exc.name}. Is it imported under TYPE_CHECKING?" + ) for attr, type_ in annotations.items(): if attr in item.funcargs: @@ -937,7 +947,7 @@ def drop_tables(): @pytest.fixture() def db_session_truncate( funnel, app, database, app_context -) -> Iterator[DatabaseSessionClass]: +) -> Iterator[DatabaseSessionClass | scoped_session]: """Empty the database after each use of the fixture.""" yield database.session sa_orm.close_all_sessions() @@ -996,7 +1006,7 @@ def get_bind( @pytest.fixture() def db_session_rollback( funnel, app, database, app_context -) -> Iterator[DatabaseSessionClass]: +) -> Iterator[DatabaseSessionClass | scoped_session]: """Create a nested transaction for the test and rollback after.""" original_session = database.session @@ -1036,7 +1046,7 @@ def db_session_rollback( @pytest.fixture() -def db_session(request) -> DatabaseSessionClass: +def db_session(request) -> DatabaseSessionClass | scoped_session: """ Database session fixture. @@ -1148,6 +1158,7 @@ def csrf_token(app, client) -> str: return token +@runtime_checkable class LoginFixtureProtocol(Protocol): def as_(self, user: funnel_models.User) -> None: ... diff --git a/tests/integration/views/label_views_test.py b/tests/integration/views/label_views_test.py index 0f83bf7b3..a00c27831 100644 --- a/tests/integration/views/label_views_test.py +++ b/tests/integration/views/label_views_test.py @@ -1,13 +1,24 @@ """Test Label views.""" import pytest +from flask import Flask +from flask.testing import FlaskClient +from sqlalchemy.orm import scoped_session from funnel import models +from ...conftest import LoginFixtureProtocol + @pytest.mark.dbcommit() def test_manage_labels_view( - app, client, login, new_project, new_user, new_label, new_main_label + app: Flask, + client: FlaskClient, + login: LoginFixtureProtocol, + new_project: models.Project, + new_user: models.User, + new_label: models.Label, + new_main_label: models.Label, ) -> None: login.as_(new_user) resp = client.get(new_project.url_for('labels')) @@ -17,7 +28,13 @@ def test_manage_labels_view( @pytest.mark.dbcommit() -def test_edit_option_label_view(app, client, login, new_user, new_main_label) -> None: +def test_edit_option_label_view( + app: Flask, + client: FlaskClient, + login: LoginFixtureProtocol, + new_user: models.User, + new_main_label: models.Label, +) -> None: login.as_(new_user) opt_label = new_main_label.options[0] resp = client.post(opt_label.url_for('edit'), follow_redirects=True) @@ -25,18 +42,30 @@ def test_edit_option_label_view(app, client, login, new_user, new_main_label) -> assert "Only main labels can be edited" in resp.data.decode('utf-8') -@pytest.mark.xfail(reason="Broken by Flask-SQLAlchemy 3.0, unclear why") # FIXME -def test_main_label_delete(client, login, new_user, new_label) -> None: +@pytest.mark.xfail(reason="Broken after Flask-SQLAlchemy 3.0, unclear why") # FIXME +def test_main_label_delete( + db_session: scoped_session, + client: FlaskClient, + login: LoginFixtureProtocol, + new_user: models.User, + new_label: models.Label, +) -> None: login.as_(new_user) resp = client.post(new_label.url_for('delete'), follow_redirects=True) assert "Manage labels" in resp.data.decode('utf-8') assert "The label has been deleted" in resp.data.decode('utf-8') - label = models.Label.query.get(new_label.id) + label = db_session.get(models.Label, new_label.id) assert label is None @pytest.mark.xfail(reason="Broken after Flask-SQLAlchemy 3.0, unclear why") # FIXME -def test_optioned_label_delete(client, login, new_user, new_main_label) -> None: +def test_optioned_label_delete( + db_session: scoped_session, + client: FlaskClient, + login: LoginFixtureProtocol, + new_user: models.User, + new_main_label: models.Label, +) -> None: login.as_(new_user) label_a1 = new_main_label.options[0] label_a2 = new_main_label.options[1] @@ -45,7 +74,7 @@ def test_optioned_label_delete(client, login, new_user, new_main_label) -> None: resp = client.post(new_main_label.url_for('delete'), follow_redirects=True) assert "Manage labels" in resp.data.decode('utf-8') assert "The label has been deleted" in resp.data.decode('utf-8') - mlabel = models.Label.query.get(new_main_label.id) + mlabel = db_session.get(models.Label, new_main_label.id) assert mlabel is None # so the option labels should have been deleted as well