From 12b9571eac64823ebf66276ba796644ab924bb48 Mon Sep 17 00:00:00 2001 From: Steve Astels Date: Wed, 22 May 2024 11:12:31 -0400 Subject: [PATCH] Use Pinpoint by default (#2173) --- .env.example | 1 + app/clients/sms/aws_pinpoint.py | 9 ++- app/clients/sms/aws_sns.py | 8 +- app/config.py | 1 + app/delivery/send_to_providers.py | 64 +++++++++++++--- .../versions/0450_enable_pinpoint_provider.py | 19 +++++ tests/app/clients/test_aws_pinpoint.py | 73 +++++++++++++++++++ tests/app/dao/test_provider_details_dao.py | 7 +- tests/app/delivery/test_send_to_providers.py | 71 +++++++++++++----- 9 files changed, 215 insertions(+), 38 deletions(-) create mode 100644 migrations/versions/0450_enable_pinpoint_provider.py create mode 100644 tests/app/clients/test_aws_pinpoint.py diff --git a/.env.example b/.env.example index cb36eefda5..6557dd4a88 100644 --- a/.env.example +++ b/.env.example @@ -22,3 +22,4 @@ CONTACT_FORM_EMAIL_ADDRESS = "" AWS_PINPOINT_SC_POOL_ID= AWS_PINPOINT_SC_TEMPLATE_IDS= +AWS_PINPOINT_DEFAULT_POOL_ID= diff --git a/app/clients/sms/aws_pinpoint.py b/app/clients/sms/aws_pinpoint.py index 37140323c0..bdb3ba7fa7 100644 --- a/app/clients/sms/aws_pinpoint.py +++ b/app/clients/sms/aws_pinpoint.py @@ -14,7 +14,6 @@ class AwsPinpointClient(SmsClient): def init_app(self, current_app, statsd_client, *args, **kwargs): self._client = boto3.client("pinpoint-sms-voice-v2", region_name="ca-central-1") super(AwsPinpointClient, self).__init__(*args, **kwargs) - # super(SmsClient, self).__init__(*args, **kwargs) self.current_app = current_app self.name = "pinpoint" self.statsd_client = statsd_client @@ -22,11 +21,15 @@ def init_app(self, current_app, statsd_client, *args, **kwargs): def get_name(self): return self.name - def send_sms(self, to, content, reference, multi=True, sender=None): - pool_id = self.current_app.config["AWS_PINPOINT_SC_POOL_ID"] + def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None): messageType = "TRANSACTIONAL" matched = False + if template_id is not None and str(template_id) in self.current_app.config["AWS_PINPOINT_SC_TEMPLATE_IDS"]: + pool_id = self.current_app.config["AWS_PINPOINT_SC_POOL_ID"] + else: + pool_id = self.current_app.config["AWS_PINPOINT_DEFAULT_POOL_ID"] + for match in phonenumbers.PhoneNumberMatcher(to, "US"): matched = True to = phonenumbers.format_number(match.number, phonenumbers.PhoneNumberFormat.E164) diff --git a/app/clients/sms/aws_sns.py b/app/clients/sms/aws_sns.py index cf6fe3e914..4847754d72 100644 --- a/app/clients/sms/aws_sns.py +++ b/app/clients/sms/aws_sns.py @@ -2,7 +2,6 @@ from time import monotonic import boto3 -import botocore import phonenumbers from notifications_utils.statsd_decorators import statsd @@ -27,7 +26,7 @@ def get_name(self): return self.name @statsd(namespace="clients.sns") - def send_sms(self, to, content, reference, multi=True, sender=None): + def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None): matched = False for match in phonenumbers.PhoneNumberMatcher(to, "US"): @@ -66,12 +65,9 @@ def send_sms(self, to, content, reference, multi=True, sender=None): try: start_time = monotonic() response = client.publish(PhoneNumber=to, Message=content, MessageAttributes=attributes) - except botocore.exceptions.ClientError as e: - self.statsd_client.incr("clients.sns.error") - raise str(e) except Exception as e: self.statsd_client.incr("clients.sns.error") - raise str(e) + raise e finally: elapsed_time = monotonic() - start_time self.current_app.logger.info("AWS SNS request finished in {}".format(elapsed_time)) diff --git a/app/config.py b/app/config.py index fa8e0e389d..aab8422f27 100644 --- a/app/config.py +++ b/app/config.py @@ -267,6 +267,7 @@ class Config(object): AWS_SES_SECRET_KEY = os.getenv("AWS_SES_SECRET_KEY") AWS_PINPOINT_REGION = os.getenv("AWS_PINPOINT_REGION", "us-west-2") AWS_PINPOINT_SC_POOL_ID = os.getenv("AWS_PINPOINT_SC_POOL_ID", None) + AWS_PINPOINT_DEFAULT_POOL_ID = os.getenv("AWS_PINPOINT_DEFAULT_POOL_ID", None) AWS_PINPOINT_CONFIGURATION_SET_NAME = os.getenv("AWS_PINPOINT_CONFIGURATION_SET_NAME", "pinpoint-configuration") AWS_PINPOINT_SC_TEMPLATE_IDS = env.list("AWS_PINPOINT_SC_TEMPLATE_IDS", []) AWS_US_TOLL_FREE_NUMBER = os.getenv("AWS_US_TOLL_FREE_NUMBER") diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index c291bbd16a..c7de7a32c2 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -2,9 +2,10 @@ import os import re from datetime import datetime -from typing import Dict +from typing import Any, Dict, Optional from uuid import UUID +import phonenumbers from flask import current_app from notifications_utils.recipients import ( validate_and_format_email_address, @@ -48,6 +49,7 @@ NOTIFICATION_VIRUS_SCAN_FAILED, PINPOINT_PROVIDER, SMS_TYPE, + SNS_PROVIDER, BounceRateStatus, Notification, Service, @@ -67,9 +69,9 @@ def send_sms_to_provider(notification): provider = provider_to_use( SMS_TYPE, notification.id, + notification.to, notification.international, notification.reply_to_text, - template_id=notification.template_id, ) template_dict = dao_get_template_by_id(notification.template_id, notification.template_version).__dict__ @@ -105,6 +107,7 @@ def send_sms_to_provider(notification): content=str(template), reference=str(notification.id), sender=notification.reply_to_text, + template_id=notification.template_id, ) except Exception as e: notification.billable_units = template.fragment_count @@ -336,16 +339,55 @@ def update_notification_to_sending(notification, provider): dao_update_notification(notification) -def provider_to_use(notification_type, notification_id, international=False, sender=None, template_id=None): - # Temporary redirect setup for template IDs that are meant for the short code usage. - if notification_type == SMS_TYPE and template_id is not None and str(template_id) in Config.AWS_PINPOINT_SC_TEMPLATE_IDS: - return clients.get_client_by_name_and_type("pinpoint", SMS_TYPE) +def provider_to_use( + notification_type: str, + notification_id: UUID, + to: Optional[str] = None, + international: bool = False, + sender: Optional[str] = None, +) -> Any: + """ + Get the provider to use for sending the notification. + SMS that are being sent with a dedicated number or to a US number should not use Pinpoint. + + Args: + notification_type (str): SMS or EMAIL. + notification_id (UUID): id of notification. Just used for logging. + to (str, optional): recipient. Defaults to None. + international (bool, optional): Recipient is international. Defaults to False. + sender (str, optional): reply_to_text to use. Defaults to None. + + Raises: + Exception: No active providers. - active_providers_in_order = [ - p - for p in get_provider_details_by_notification_type(notification_type, international) - if p.active and p.identifier != PINPOINT_PROVIDER - ] + Returns: + provider: Provider to use to send the notification. + """ + + has_dedicated_number = sender is not None and sender.startswith("+1") + sending_to_us_number = False + if to is not None: + match = next(iter(phonenumbers.PhoneNumberMatcher(to, "US")), None) + if match and phonenumbers.region_code_for_number(match.number) == "US": + sending_to_us_number = True + + if ( + has_dedicated_number + or sending_to_us_number + or current_app.config["AWS_PINPOINT_SC_POOL_ID"] is None + or current_app.config["AWS_PINPOINT_DEFAULT_POOL_ID"] is None + ): + active_providers_in_order = [ + p + for p in get_provider_details_by_notification_type(notification_type, international) + if p.active and p.identifier != PINPOINT_PROVIDER + ] + else: + active_providers_in_order = [ + p + for p in get_provider_details_by_notification_type(notification_type, international) + if p.active and p.identifier != SNS_PROVIDER + ] if not active_providers_in_order: current_app.logger.error("{} {} failed as no active providers".format(notification_type, notification_id)) diff --git a/migrations/versions/0450_enable_pinpoint_provider.py b/migrations/versions/0450_enable_pinpoint_provider.py new file mode 100644 index 0000000000..0c2c8247dd --- /dev/null +++ b/migrations/versions/0450_enable_pinpoint_provider.py @@ -0,0 +1,19 @@ +""" + +Revision ID: 0450_enable_pinpoint_provider +Revises: 0449_update_magic_link_auth +Create Date: 2021-01-08 09:03:00 .214680 + +""" +from alembic import op + +revision = "0450_enable_pinpoint_provider" +down_revision = "0449_update_magic_link_auth" + + +def upgrade(): + op.execute("UPDATE provider_details set active=true where identifier in ('pinpoint');") + + +def downgrade(): + op.execute("UPDATE provider_details set active=false where identifier in ('pinpoint');") diff --git a/tests/app/clients/test_aws_pinpoint.py b/tests/app/clients/test_aws_pinpoint.py new file mode 100644 index 0000000000..ad7546d1ad --- /dev/null +++ b/tests/app/clients/test_aws_pinpoint.py @@ -0,0 +1,73 @@ +import pytest + +from app import aws_pinpoint_client +from tests.conftest import set_config_values + + +@pytest.mark.serial +def test_send_sms_sends_to_default_pool(notify_api, mocker, sample_template): + boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True) + mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True) + to = "6135555555" + content = "foo" + reference = "ref" + + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + "AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name", + "AWS_PINPOINT_SC_TEMPLATE_IDS": [], + }, + ): + aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id) + + boto_mock.send_text_message.assert_called_once_with( + DestinationPhoneNumber="+16135555555", + OriginationIdentity="default_pool_id", + MessageBody=content, + MessageType="TRANSACTIONAL", + ConfigurationSetName="config_set_name", + ) + + +@pytest.mark.serial +def test_send_sms_sends_to_shortcode_pool(notify_api, mocker, sample_template): + boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True) + mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True) + to = "6135555555" + content = "foo" + reference = "ref" + + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + "AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name", + "AWS_PINPOINT_SC_TEMPLATE_IDS": [str(sample_template.id)], + }, + ): + aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id) + + boto_mock.send_text_message.assert_called_once_with( + DestinationPhoneNumber="+16135555555", + OriginationIdentity="sc_pool_id", + MessageBody=content, + MessageType="TRANSACTIONAL", + ConfigurationSetName="config_set_name", + ) + + +def test_send_sms_returns_raises_error_if_there_is_no_valid_number_is_found(notify_api, mocker): + mocker.patch.object(aws_pinpoint_client, "_client", create=True) + mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True) + + to = "" + content = reference = "foo" + + with pytest.raises(ValueError) as excinfo: + aws_pinpoint_client.send_sms(to, content, reference) + + assert "No valid numbers found for SMS delivery" in str(excinfo.value) diff --git a/tests/app/dao/test_provider_details_dao.py b/tests/app/dao/test_provider_details_dao.py index 5b8b8e5348..6acce65192 100644 --- a/tests/app/dao/test_provider_details_dao.py +++ b/tests/app/dao/test_provider_details_dao.py @@ -241,9 +241,14 @@ def test_get_sms_provider_with_equal_priority_returns_provider( def test_get_current_sms_provider_returns_active_only(restore_provider_details): + # Note that we currently have two active sms providers: sns and pinpoint. current_provider = get_current_provider("sms") current_provider.active = False dao_update_provider_details(current_provider) + current_provider = get_current_provider("sms") + current_provider.active = False + dao_update_provider_details(current_provider) + new_current_provider = get_current_provider("sms") assert new_current_provider is None @@ -308,5 +313,5 @@ def test_dao_get_provider_stats(notify_db_session): assert result[5].identifier == "pinpoint" assert result[5].notification_type == "sms" assert result[5].supports_international is False - assert result[5].active is False + assert result[5].active is True assert result[5].current_month_billable_sms == 0 diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index ea91e1a503..a8637afdfe 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -51,6 +51,53 @@ from tests.conftest import set_config_values +class TestProviderToUse: + def test_should_use_pinpoint_for_sms_by_default(self, restore_provider_details, notify_api): + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+16135551234") + assert provider.name == "pinpoint" + + def test_should_use_sns_for_sms_if_dedicated_number(self, restore_provider_details, notify_api): + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+16135551234", False, "+12345678901") + assert provider.name == "sns" + + def test_should_use_sns_for_sms_if_sending_to_the_US(self, restore_provider_details, notify_api): + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+17065551234") + assert provider.name == "sns" + + @pytest.mark.parametrize("sc_pool_id, default_pool_id", [(None, "default_pool_id"), ("sc_pool_id", None)]) + def test_should_use_sns_if_pinpoint_not_configured(self, restore_provider_details, notify_api, sc_pool_id, default_pool_id): + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": sc_pool_id, + "AWS_PINPOINT_DEFAULT_POOL_ID": default_pool_id, + }, + ): + provider = send_to_providers.provider_to_use("sms", "1234", "+16135551234") + assert provider.name == "sns" + + @pytest.mark.skip(reason="Currently using only 1 SMS provider") def test_should_return_highest_priority_active_provider(restore_provider_details): providers = provider_details_dao.get_provider_details_by_notification_type("sms") @@ -84,21 +131,6 @@ def test_should_return_highest_priority_active_provider(restore_provider_details assert send_to_providers.provider_to_use("sms", "1234").name == first.identifier -def test_provider_to_use(restore_provider_details): - providers = provider_details_dao.get_provider_details_by_notification_type("sms") - first = providers[0] - - assert first.identifier == "sns" - - # provider is still SNS if SMS and sender is set - provider = send_to_providers.provider_to_use("sms", "1234", False, "+12345678901") - assert first.identifier == provider.name - - # provider is highest priority sms provider if sender is not set - provider = send_to_providers.provider_to_use("sms", "1234", False) - assert first.identifier == provider.name - - def test_should_send_personalised_template_to_correct_sms_provider_and_persist(sample_sms_template_with_html, mocker): db_notification = save_notification( create_notification( @@ -120,6 +152,7 @@ def test_should_send_personalised_template_to_correct_sms_provider_and_persist(s content="Sample service: Hello Jo\nHere is some HTML & entities", reference=str(db_notification.id), sender=current_app.config["FROM_NUMBER"], + template_id=sample_sms_template_with_html.id, ) notification = Notification.query.filter_by(id=db_notification.id).one() @@ -338,6 +371,7 @@ def test_send_sms_should_use_template_version_from_notification_not_latest(sampl content="Sample service: This is a template:\nwith a newline", reference=str(db_notification.id), sender=current_app.config["FROM_NUMBER"], + template_id=sample_template.id, ) persisted_notification = notifications_dao.get_notification_by_id(db_notification.id) @@ -416,7 +450,7 @@ def test_should_send_sms_with_downgraded_content(notify_db_session, mocker): send_to_providers.send_sms_to_provider(db_notification) - aws_sns_client.send_sms.assert_called_once_with(to=ANY, content=gsm_message, reference=ANY, sender=ANY) + aws_sns_client.send_sms.assert_called_once_with(to=ANY, content=gsm_message, reference=ANY, sender=ANY, template_id=ANY) def test_send_sms_should_use_service_sms_sender(sample_service, sample_template, mocker): @@ -429,7 +463,9 @@ def test_send_sms_should_use_service_sms_sender(sample_service, sample_template, db_notification, ) - app.aws_sns_client.send_sms.assert_called_once_with(to=ANY, content=ANY, reference=ANY, sender=sms_sender.sms_sender) + app.aws_sns_client.send_sms.assert_called_once_with( + to=ANY, content=ANY, reference=ANY, sender=sms_sender.sms_sender, template_id=ANY + ) @pytest.mark.parametrize("research_mode,key_type", [(True, KEY_TYPE_NORMAL), (False, KEY_TYPE_TEST)]) @@ -800,6 +836,7 @@ def test_should_handle_sms_sender_and_prefix_message( sender=expected_sender, to=ANY, reference=ANY, + template_id=ANY, )