Skip to content

Commit

Permalink
Use template category sending_vehicle in pinpoint (#2216)
Browse files Browse the repository at this point in the history
  • Loading branch information
sastels authored Jul 22, 2024
1 parent 32737c5 commit a0ac453
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 13 deletions.
7 changes: 7 additions & 0 deletions app/clients/sms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from enum import Enum

from app.clients import Client, ClientException


class SmsSendingVehicles(Enum):
SHORT_CODE = "short_code"
LONG_CODE = "long_code"


class SmsClientResponseException(ClientException):
"""
Base Exception for SmsClientsResponses
Expand Down
18 changes: 12 additions & 6 deletions app/clients/sms/aws_pinpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import boto3
import phonenumbers

from app.clients.sms import SmsClient
from app.clients.sms import SmsClient, SmsSendingVehicles


class AwsPinpointClient(SmsClient):
Expand All @@ -21,16 +21,22 @@ def init_app(self, current_app, statsd_client, *args, **kwargs):
def get_name(self):
return self.name

def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None, service_id=None):
def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None, service_id=None, sending_vehicle=None):
messageType = "TRANSACTIONAL"
matched = False
opted_out = False
response = {}

use_shortcode_pool = (
str(template_id) in self.current_app.config["AWS_PINPOINT_SC_TEMPLATE_IDS"]
or str(service_id) == self.current_app.config["NOTIFY_SERVICE_ID"]
)
if self.current_app.config["FF_TEMPLATE_CATEGORY"]:
use_shortcode_pool = (
sending_vehicle == SmsSendingVehicles.SHORT_CODE
or str(service_id) == self.current_app.config["NOTIFY_SERVICE_ID"]
)
else:
use_shortcode_pool = (
str(template_id) in self.current_app.config["AWS_PINPOINT_SC_TEMPLATE_IDS"]
or str(service_id) == self.current_app.config["NOTIFY_SERVICE_ID"]
)
if use_shortcode_pool:
pool_id = self.current_app.config["AWS_PINPOINT_SC_POOL_ID"]
else:
Expand Down
12 changes: 11 additions & 1 deletion app/delivery/send_to_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@

from app import bounce_rate_client, clients, document_download_client, statsd_client
from app.celery.research_mode_tasks import send_email_response, send_sms_response
from app.clients.sms import SmsSendingVehicles
from app.config import Config
from app.dao.notifications_dao import dao_update_notification
from app.dao.provider_details_dao import (
dao_toggle_sms_provider,
get_provider_details_by_notification_type,
)
from app.dao.template_categories_dao import dao_get_template_category_by_id
from app.dao.templates_dao import dao_get_template_by_id
from app.exceptions import (
DocumentDownloadException,
Expand Down Expand Up @@ -104,13 +106,21 @@ def send_sms_to_provider(notification):

else:
try:
template_category_id = template_dict.get("template_category_id")
if current_app.config["FF_TEMPLATE_CATEGORY"] and template_category_id is not None:
sending_vehicle = SmsSendingVehicles(
dao_get_template_category_by_id(template_category_id).sms_sending_vehicle
)
else:
sending_vehicle = None
reference = provider.send_sms(
to=validate_and_format_phone_number(notification.to, international=notification.international),
content=str(template),
reference=str(notification.id),
sender=notification.reply_to_text,
template_id=notification.template_id,
service_id=service.id,
service_id=notification.service_id,
sending_vehicle=sending_vehicle,
)
except Exception as e:
notification.billable_units = template.fragment_count
Expand Down
6 changes: 2 additions & 4 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
signer_inbound_sms,
signer_personalisation,
)
from app.clients.sms import SmsSendingVehicles
from app.encryption import check_hash, hashpw
from app.history_meta import Versioned

Expand Down Expand Up @@ -65,10 +66,7 @@
COMPLAINT_CALLBACK_TYPE = "complaint"
SERVICE_CALLBACK_TYPES = [DELIVERY_STATUS_CALLBACK_TYPE, COMPLAINT_CALLBACK_TYPE]

SHORT_CODE = "short_code"
LONG_CODE = "long_code"

sms_sending_vehicles = db.Enum(*[SHORT_CODE, LONG_CODE], name="sms_sending_vehicles")
sms_sending_vehicles = db.Enum(*[vehicle.value for vehicle in SmsSendingVehicles], name="sms_sending_vehicles")


def filter_null_value_fields(obj):
Expand Down
49 changes: 48 additions & 1 deletion tests/app/clients/test_aws_pinpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from app import aws_pinpoint_client
from app.clients.sms import SmsSendingVehicles
from tests.conftest import set_config_values


Expand Down Expand Up @@ -34,7 +35,7 @@ def test_send_sms_sends_to_default_pool(notify_api, mocker, sample_template, tem


@pytest.mark.serial
def test_send_sms_sends_to_shortcode_pool(notify_api, mocker, sample_template):
def test_send_sms_sends_sc_template_to_shortcode_pool_with_ff_false(notify_api, mocker, sample_template):
boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True)
mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True)
to = "6135555555"
Expand All @@ -48,6 +49,7 @@ def test_send_sms_sends_to_shortcode_pool(notify_api, mocker, sample_template):
"AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id",
"AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name",
"AWS_PINPOINT_SC_TEMPLATE_IDS": [str(sample_template.id)],
"FF_TEMPLATE_CATEGORY": False,
},
):
aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id)
Expand Down Expand Up @@ -114,6 +116,51 @@ def test_handles_opted_out_numbers(notify_api, mocker, sample_template):
assert aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id) == "opted_out"


@pytest.mark.serial
@pytest.mark.parametrize(
"FF_TEMPLATE_CATEGORY, sending_vehicle, expected_pool",
[
(False, None, "default_pool_id"),
(False, "long_code", "default_pool_id"),
(False, "short_code", "default_pool_id"),
(True, None, "default_pool_id"),
(True, "long_code", "default_pool_id"),
(True, "short_code", "sc_pool_id"),
],
)
def test_respects_sending_vehicle_if_FF_enabled(
notify_api, mocker, sample_template, FF_TEMPLATE_CATEGORY, sending_vehicle, expected_pool
):
boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True)
mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True)
to = "6135555555"
content = "foo"
reference = "ref"
sms_sending_vehicle = None if sending_vehicle is None else SmsSendingVehicles(sending_vehicle)

with set_config_values(
notify_api,
{
"AWS_PINPOINT_SC_POOL_ID": "sc_pool_id",
"AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id",
"AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name",
"AWS_PINPOINT_SC_TEMPLATE_IDS": [],
"FF_TEMPLATE_CATEGORY": FF_TEMPLATE_CATEGORY,
},
):
aws_pinpoint_client.send_sms(
to, content, reference=reference, template_id=sample_template.id, sending_vehicle=sms_sending_vehicle
)

boto_mock.send_text_message.assert_called_once_with(
DestinationPhoneNumber="+16135555555",
OriginationIdentity=expected_pool,
MessageBody=content,
MessageType="TRANSACTIONAL",
ConfigurationSetName="config_set_name",
)


@pytest.mark.serial
def test_send_sms_sends_international_without_pool_id(notify_api, mocker, sample_template):
boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True)
Expand Down
6 changes: 5 additions & 1 deletion tests/app/delivery/test_send_to_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def test_should_send_personalised_template_to_correct_sms_provider_and_persist(s
sender=current_app.config["FROM_NUMBER"],
template_id=sample_sms_template_with_html.id,
service_id=sample_sms_template_with_html.service_id,
sending_vehicle=None,
)

notification = Notification.query.filter_by(id=db_notification.id).one()
Expand Down Expand Up @@ -458,6 +459,7 @@ def test_send_sms_should_use_template_version_from_notification_not_latest(sampl
sender=current_app.config["FROM_NUMBER"],
template_id=sample_template.id,
service_id=sample_template.service_id,
sending_vehicle=ANY,
)

persisted_notification = notifications_dao.get_notification_by_id(db_notification.id)
Expand Down Expand Up @@ -537,7 +539,7 @@ def test_should_send_sms_with_downgraded_content(notify_db_session, mocker):
send_to_providers.send_sms_to_provider(db_notification)

aws_sns_client.send_sms.assert_called_once_with(
to=ANY, content=gsm_message, reference=ANY, sender=ANY, template_id=ANY, service_id=ANY
to=ANY, content=gsm_message, reference=ANY, sender=ANY, template_id=ANY, service_id=ANY, sending_vehicle=ANY
)


Expand All @@ -558,6 +560,7 @@ def test_send_sms_should_use_service_sms_sender(sample_service, sample_template,
sender=sms_sender.sms_sender,
template_id=ANY,
service_id=ANY,
sending_vehicle=ANY,
)


Expand Down Expand Up @@ -931,6 +934,7 @@ def test_should_handle_sms_sender_and_prefix_message(
reference=ANY,
template_id=ANY,
service_id=ANY,
sending_vehicle=ANY,
)


Expand Down

0 comments on commit a0ac453

Please sign in to comment.