diff --git a/api/api_gateway/api.py b/api/api_gateway/api.py index 1e4f07dd..6d1fd0d4 100644 --- a/api/api_gateway/api.py +++ b/api/api_gateway/api.py @@ -7,6 +7,7 @@ from fastapi.responses import RedirectResponse, JSONResponse from clients.notify import NotificationsAPIClient from requests import HTTPError +from sqlalchemy.engine.row import Row from sqlalchemy.exc import SQLAlchemyError, NoResultFound from sqlalchemy.sql.expression import func, cast from sqlalchemy.orm import Session @@ -394,7 +395,7 @@ def update_list( _authorized: bool = Depends(verify_token), ): try: - list = session.query(List).get(list_id) + list = session.get(List, list_id) if list is None: raise NoResultFound except SQLAlchemyError: @@ -426,7 +427,7 @@ def delete_list( _authorized: bool = Depends(verify_token), ): try: - list = session.query(List).get(list_id) + list = session.get(List, list_id) if list is None: raise NoResultFound except SQLAlchemyError: @@ -457,7 +458,7 @@ def reset_list( _authorized: bool = Depends(verify_token), ): try: - list = session.query(List).get(list_id) + list = session.get(List, list_id) if list is None: raise NoResultFound except SQLAlchemyError: @@ -513,7 +514,7 @@ def create_subscription( notifications_client = get_notify_client() try: - list = session.query(List).get(subscription_payload.list_id) + list = session.get(List, subscription_payload.list_id) if list is None: raise NoResultFound except SQLAlchemyError: @@ -613,7 +614,7 @@ def confirm_subscription( subscription_id, response: Response, session: Session = Depends(get_db) ): try: - subscription = session.query(Subscription).get(subscription_id) + subscription = session.get(Subscription, subscription_id) if subscription is None: raise NoResultFound @@ -651,7 +652,7 @@ def unsubscribe( notifications_client = get_notify_client() try: - subscription = session.query(Subscription).get(subscription_id) + subscription = session.get(Subscription, subscription_id) if subscription is None: raise NoResultFound except SQLAlchemyError: @@ -659,7 +660,7 @@ def unsubscribe( return {"error": "subscription not found"} try: - list = session.query(List).get(subscription.list_id) + list = session.get(List, subscription.list_id) email = subscription.email phone = subscription.phone @@ -800,6 +801,10 @@ def send_bulk_notify(subscription_count, send_payload, rows, recipient_limit=500 template_type = send_payload.template_type.lower() # Split notifications into separate calls based on limit for i, row in enumerate(rows): + # Convert SQLAlchemy Row objects to dicts + if type(row) is Row: + row = row._mapping + if i > 0 and (i % recipient_limit == 0): notify_bulk_subscribers.append(subscription_rows) diff --git a/api/requirements.txt b/api/requirements.txt index 18b2df80..d102b849 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -9,4 +9,4 @@ mangum==0.17.0 notifications-python-client==8.0.1 pydantic==1.10.12 psycopg2-binary==2.9.7 -SQLAlchemy==1.4.49 \ No newline at end of file +SQLAlchemy==2.0.21 \ No newline at end of file diff --git a/api/tests/api_gateway/test_api.py b/api/tests/api_gateway/test_api.py index 3203d891..f3ba7995 100644 --- a/api/tests/api_gateway/test_api.py +++ b/api/tests/api_gateway/test_api.py @@ -349,7 +349,7 @@ def test_edit_list_with_correct_id(session, client): assert response.json() == {"status": "OK"} assert response.status_code == 200 session.expire_all() - list = session.query(List).get(list.id) + list = session.get(List, list.id) assert list.name == "edited_name" assert list.language == "edited_language" assert list.service_id == "edited_service_id" @@ -372,7 +372,7 @@ def test_edit_list_without_supplying_service_id_and_name(session, client): assert response.json() == {"status": "OK"} assert response.status_code == 200 session.expire_all() - list = session.query(List).get(list.id) + list = session.get(List, list.id) assert list.subscribe_email_template_id == "ea974231-002b-4889-87f1-0b9cf48e9411" diff --git a/api/tests/api_gateway/test_api_send.py b/api/tests/api_gateway/test_api_send.py index 992338da..b4acec4e 100644 --- a/api/tests/api_gateway/test_api_send.py +++ b/api/tests/api_gateway/test_api_send.py @@ -4,6 +4,7 @@ from unittest.mock import patch from requests import HTTPError from models.Subscription import Subscription +from sqlalchemy import text @patch("api_gateway.api.get_notify_client") @@ -36,7 +37,7 @@ def test_send_email(mock_client, list_fixture, client, session): @patch("api_gateway.api.get_notify_client") def test_send_email_with_personalisation(mock_client, list_fixture, client, session): - session.execute("""TRUNCATE TABLE subscriptions""") + session.execute(text("TRUNCATE TABLE subscriptions")) session.commit() subscription = Subscription(