diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index f4a758aba8..4499fb8a86 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -2,9 +2,10 @@ import os import re from datetime import datetime -from typing import Dict +from typing import Any, Dict, Union from uuid import UUID +import phonenumbers from flask import current_app from notifications_utils.recipients import ( validate_and_format_email_address, @@ -48,6 +49,7 @@ NOTIFICATION_VIRUS_SCAN_FAILED, PINPOINT_PROVIDER, SMS_TYPE, + SNS_PROVIDER, BounceRateStatus, Notification, Service, @@ -67,9 +69,9 @@ def send_sms_to_provider(notification): provider = provider_to_use( SMS_TYPE, notification.id, + notification.to, notification.international, notification.reply_to_text, - template_id=notification.template_id, ) template_dict = dao_get_template_by_id(notification.template_id, notification.template_version).__dict__ @@ -337,9 +339,44 @@ def update_notification_to_sending(notification, provider): dao_update_notification(notification) -def provider_to_use(notification_type, notification_id, international=False, sender=None, template_id=None): - # TODO: remove the first option once we have pinpoint fully integrated - if Config.AWS_PINPOINT_SC_POOL_ID is None or Config.AWS_PINPOINT_DEFAULT_POOL_ID is None: +def provider_to_use( + notification_type: str, + notification_id: UUID, + to: Union[str, None] = None, + international: bool = False, + sender: Union[str, None] = None, +) -> Any: + """ + Get the provider to use for sending the notification. + SMS that are being sent with a dedicated number or to a US number should not use Pinpoint. + + Args: + notification_type (str): SMS or EMAIL. + notification_id (UUID): id of notification. Just used for logging. + to (str, optional): recipient. Defaults to None. + international (bool, optional): Recipient is international. Defaults to False. + sender (str, optional): reply_to_text to use. Defaults to None. + + Raises: + Exception: No active providers. + + Returns: + provider: Provider to use to send the notification. + """ + + has_dedicated_number = sender is not None and sender.startswith("+1") + sending_to_us_number = False + if to is not None: + match = next(iter(phonenumbers.PhoneNumberMatcher(to, "US")), None) + if match and phonenumbers.region_code_for_number(match.number) == "US": + sending_to_us_number = True + + if ( + has_dedicated_number + or sending_to_us_number + or current_app.config["AWS_PINPOINT_SC_POOL_ID"] is None + or current_app.config["AWS_PINPOINT_DEFAULT_POOL_ID"] is None + ): active_providers_in_order = [ p for p in get_provider_details_by_notification_type(notification_type, international) @@ -347,7 +384,9 @@ def provider_to_use(notification_type, notification_id, international=False, sen ] else: active_providers_in_order = [ - p for p in get_provider_details_by_notification_type(notification_type, international) if p.active + p + for p in get_provider_details_by_notification_type(notification_type, international) + if p.active and p.identifier != SNS_PROVIDER ] if not active_providers_in_order: diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index 571397827d..9858c46b67 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -51,6 +51,70 @@ from tests.conftest import set_config_values +class TestProviderToUse: + def test_should_use_pinpoint_for_sms_by_default(self, restore_provider_details, notify_api): + providers = provider_details_dao.get_provider_details_by_notification_type("sms") + for provider in providers: + if provider.identifier == "pinpoint": + provider.active = True + provider_details_dao.dao_update_provider_details(provider) + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+16135551234") + assert provider.name == "pinpoint" + + def test_should_use_sns_for_sms_if_dedicated_number(self, restore_provider_details, notify_api): + providers = provider_details_dao.get_provider_details_by_notification_type("sms") + for provider in providers: + if provider.identifier == "pinpoint": + provider.active = True + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+16135551234", False, "+12345678901") + assert provider.name == "sns" + + def test_should_use_sns_for_sms_if_sending_to_the_US(self, restore_provider_details, notify_api): + providers = provider_details_dao.get_provider_details_by_notification_type("sms") + for provider in providers: + if provider.identifier == "pinpoint": + provider.active = True + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+17065551234") + assert provider.name == "sns" + + @pytest.mark.parametrize("sc_pool_id, default_pool_id", [(None, "default_pool_id"), ("sc_pool_id", None)]) + def test_should_use_sns_if_pinpoint_not_configured(self, restore_provider_details, notify_api, sc_pool_id, default_pool_id): + providers = provider_details_dao.get_provider_details_by_notification_type("sms") + for provider in providers: + if provider.identifier == "pinpoint": + provider.active = True + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": sc_pool_id, + "AWS_PINPOINT_DEFAULT_POOL_ID": default_pool_id, + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+16135551234") + assert provider.name == "sns" + + @pytest.mark.skip(reason="Currently using only 1 SMS provider") def test_should_return_highest_priority_active_provider(restore_provider_details): providers = provider_details_dao.get_provider_details_by_notification_type("sms") @@ -84,21 +148,6 @@ def test_should_return_highest_priority_active_provider(restore_provider_details assert send_to_providers.provider_to_use("sms", "1234").name == first.identifier -def test_provider_to_use(restore_provider_details): - providers = provider_details_dao.get_provider_details_by_notification_type("sms") - first = providers[0] - - assert first.identifier == "sns" - - # provider is still SNS if SMS and sender is set - provider = send_to_providers.provider_to_use("sms", "1234", False, "+12345678901") - assert first.identifier == provider.name - - # provider is highest priority sms provider if sender is not set - provider = send_to_providers.provider_to_use("sms", "1234", False) - assert first.identifier == provider.name - - def test_should_send_personalised_template_to_correct_sms_provider_and_persist(sample_sms_template_with_html, mocker): db_notification = save_notification( create_notification(