diff --git a/app/clients/sms/__init__.py b/app/clients/sms/__init__.py index 8d6472d19d..88ab822075 100644 --- a/app/clients/sms/__init__.py +++ b/app/clients/sms/__init__.py @@ -1,6 +1,13 @@ +from enum import Enum + from app.clients import Client, ClientException +class SmsSendingVehicles(Enum): + SHORT_CODE = "short_code" + LONG_CODE = "long_code" + + class SmsClientResponseException(ClientException): """ Base Exception for SmsClientsResponses diff --git a/app/clients/sms/aws_pinpoint.py b/app/clients/sms/aws_pinpoint.py index 4b18b1c7a3..57c58c9f13 100644 --- a/app/clients/sms/aws_pinpoint.py +++ b/app/clients/sms/aws_pinpoint.py @@ -3,7 +3,7 @@ import boto3 import phonenumbers -from app.clients.sms import SmsClient +from app.clients.sms import SmsClient, SmsSendingVehicles class AwsPinpointClient(SmsClient): @@ -21,16 +21,22 @@ def init_app(self, current_app, statsd_client, *args, **kwargs): def get_name(self): return self.name - def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None, service_id=None): + def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None, service_id=None, sending_vehicle=None): messageType = "TRANSACTIONAL" matched = False opted_out = False response = {} - use_shortcode_pool = ( - str(template_id) in self.current_app.config["AWS_PINPOINT_SC_TEMPLATE_IDS"] - or str(service_id) == self.current_app.config["NOTIFY_SERVICE_ID"] - ) + if self.current_app.config["FF_TEMPLATE_CATEGORY"]: + use_shortcode_pool = ( + sending_vehicle == SmsSendingVehicles.SHORT_CODE + or str(service_id) == self.current_app.config["NOTIFY_SERVICE_ID"] + ) + else: + use_shortcode_pool = ( + str(template_id) in self.current_app.config["AWS_PINPOINT_SC_TEMPLATE_IDS"] + or str(service_id) == self.current_app.config["NOTIFY_SERVICE_ID"] + ) if use_shortcode_pool: pool_id = self.current_app.config["AWS_PINPOINT_SC_POOL_ID"] else: diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index 5979517c75..5bcaed1f34 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -22,12 +22,14 @@ from app import bounce_rate_client, clients, document_download_client, statsd_client from app.celery.research_mode_tasks import send_email_response, send_sms_response +from app.clients.sms import SmsSendingVehicles from app.config import Config from app.dao.notifications_dao import dao_update_notification from app.dao.provider_details_dao import ( dao_toggle_sms_provider, get_provider_details_by_notification_type, ) +from app.dao.template_categories_dao import dao_get_template_category_by_id from app.dao.templates_dao import dao_get_template_by_id from app.exceptions import ( DocumentDownloadException, @@ -104,13 +106,21 @@ def send_sms_to_provider(notification): else: try: + template_category_id = template_dict.get("template_category_id") + if current_app.config["FF_TEMPLATE_CATEGORY"] and template_category_id is not None: + sending_vehicle = SmsSendingVehicles( + dao_get_template_category_by_id(template_category_id).sms_sending_vehicle + ) + else: + sending_vehicle = None reference = provider.send_sms( to=validate_and_format_phone_number(notification.to, international=notification.international), content=str(template), reference=str(notification.id), sender=notification.reply_to_text, template_id=notification.template_id, - service_id=service.id, + service_id=notification.service_id, + sending_vehicle=sending_vehicle, ) except Exception as e: notification.billable_units = template.fragment_count diff --git a/app/models.py b/app/models.py index 86f47dac35..74000e2b09 100644 --- a/app/models.py +++ b/app/models.py @@ -38,6 +38,7 @@ signer_inbound_sms, signer_personalisation, ) +from app.clients.sms import SmsSendingVehicles from app.encryption import check_hash, hashpw from app.history_meta import Versioned @@ -65,10 +66,7 @@ COMPLAINT_CALLBACK_TYPE = "complaint" SERVICE_CALLBACK_TYPES = [DELIVERY_STATUS_CALLBACK_TYPE, COMPLAINT_CALLBACK_TYPE] -SHORT_CODE = "short_code" -LONG_CODE = "long_code" - -sms_sending_vehicles = db.Enum(*[SHORT_CODE, LONG_CODE], name="sms_sending_vehicles") +sms_sending_vehicles = db.Enum(*[vehicle.value for vehicle in SmsSendingVehicles], name="sms_sending_vehicles") def filter_null_value_fields(obj): diff --git a/tests/app/clients/test_aws_pinpoint.py b/tests/app/clients/test_aws_pinpoint.py index b681025c1b..b913b1c39b 100644 --- a/tests/app/clients/test_aws_pinpoint.py +++ b/tests/app/clients/test_aws_pinpoint.py @@ -1,6 +1,7 @@ import pytest from app import aws_pinpoint_client +from app.clients.sms import SmsSendingVehicles from tests.conftest import set_config_values @@ -34,7 +35,7 @@ def test_send_sms_sends_to_default_pool(notify_api, mocker, sample_template, tem @pytest.mark.serial -def test_send_sms_sends_to_shortcode_pool(notify_api, mocker, sample_template): +def test_send_sms_sends_sc_template_to_shortcode_pool_with_ff_false(notify_api, mocker, sample_template): boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True) mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True) to = "6135555555" @@ -48,6 +49,7 @@ def test_send_sms_sends_to_shortcode_pool(notify_api, mocker, sample_template): "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", "AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name", "AWS_PINPOINT_SC_TEMPLATE_IDS": [str(sample_template.id)], + "FF_TEMPLATE_CATEGORY": False, }, ): aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id) @@ -114,6 +116,51 @@ def test_handles_opted_out_numbers(notify_api, mocker, sample_template): assert aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id) == "opted_out" +@pytest.mark.serial +@pytest.mark.parametrize( + "FF_TEMPLATE_CATEGORY, sending_vehicle, expected_pool", + [ + (False, None, "default_pool_id"), + (False, "long_code", "default_pool_id"), + (False, "short_code", "default_pool_id"), + (True, None, "default_pool_id"), + (True, "long_code", "default_pool_id"), + (True, "short_code", "sc_pool_id"), + ], +) +def test_respects_sending_vehicle_if_FF_enabled( + notify_api, mocker, sample_template, FF_TEMPLATE_CATEGORY, sending_vehicle, expected_pool +): + boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True) + mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True) + to = "6135555555" + content = "foo" + reference = "ref" + sms_sending_vehicle = None if sending_vehicle is None else SmsSendingVehicles(sending_vehicle) + + with set_config_values( + notify_api, + { + "AWS_PINPOINT_SC_POOL_ID": "sc_pool_id", + "AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id", + "AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name", + "AWS_PINPOINT_SC_TEMPLATE_IDS": [], + "FF_TEMPLATE_CATEGORY": FF_TEMPLATE_CATEGORY, + }, + ): + aws_pinpoint_client.send_sms( + to, content, reference=reference, template_id=sample_template.id, sending_vehicle=sms_sending_vehicle + ) + + boto_mock.send_text_message.assert_called_once_with( + DestinationPhoneNumber="+16135555555", + OriginationIdentity=expected_pool, + MessageBody=content, + MessageType="TRANSACTIONAL", + ConfigurationSetName="config_set_name", + ) + + @pytest.mark.serial def test_send_sms_sends_international_without_pool_id(notify_api, mocker, sample_template): boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True) diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index d0c0d6812e..56b0eef3de 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -238,6 +238,7 @@ def test_should_send_personalised_template_to_correct_sms_provider_and_persist(s sender=current_app.config["FROM_NUMBER"], template_id=sample_sms_template_with_html.id, service_id=sample_sms_template_with_html.service_id, + sending_vehicle=None, ) notification = Notification.query.filter_by(id=db_notification.id).one() @@ -458,6 +459,7 @@ def test_send_sms_should_use_template_version_from_notification_not_latest(sampl sender=current_app.config["FROM_NUMBER"], template_id=sample_template.id, service_id=sample_template.service_id, + sending_vehicle=ANY, ) persisted_notification = notifications_dao.get_notification_by_id(db_notification.id) @@ -537,7 +539,7 @@ def test_should_send_sms_with_downgraded_content(notify_db_session, mocker): send_to_providers.send_sms_to_provider(db_notification) aws_sns_client.send_sms.assert_called_once_with( - to=ANY, content=gsm_message, reference=ANY, sender=ANY, template_id=ANY, service_id=ANY + to=ANY, content=gsm_message, reference=ANY, sender=ANY, template_id=ANY, service_id=ANY, sending_vehicle=ANY ) @@ -558,6 +560,7 @@ def test_send_sms_should_use_service_sms_sender(sample_service, sample_template, sender=sms_sender.sms_sender, template_id=ANY, service_id=ANY, + sending_vehicle=ANY, ) @@ -931,6 +934,7 @@ def test_should_handle_sms_sender_and_prefix_message( reference=ANY, template_id=ANY, service_id=ANY, + sending_vehicle=ANY, )