diff --git a/.github/workflows/export_github_data.yml b/.github/workflows/export_github_data.yml
index db925884..2cd11121 100644
--- a/.github/workflows/export_github_data.yml
+++ b/.github/workflows/export_github_data.yml
@@ -14,7 +14,7 @@ jobs:
DNS_PROXY_FORWARDTOSENTINEL: "true"
DNS_PROXY_LOGANALYTICSWORKSPACEID: ${{ secrets.LOG_ANALYTICS_WORKSPACE_ID }}
DNS_PROXY_LOGANALYTICSSHAREDKEY: ${{ secrets.LOG_ANALYTICS_WORKSPACE_KEY }}
- - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
- name: Export Data
uses: cds-snc/github-repository-metadata-exporter@main
with:
diff --git a/.github/workflows/ossf-scorecard.yml b/.github/workflows/ossf-scorecard.yml
index 87010d01..556ccb91 100644
--- a/.github/workflows/ossf-scorecard.yml
+++ b/.github/workflows/ossf-scorecard.yml
@@ -20,7 +20,7 @@ jobs:
steps:
- name: "Checkout code"
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
with:
persist-credentials: false
diff --git a/.github/workflows/s3-backup.yml b/.github/workflows/s3-backup.yml
index 6a8e9670..d850d5e0 100644
--- a/.github/workflows/s3-backup.yml
+++ b/.github/workflows/s3-backup.yml
@@ -10,7 +10,7 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
with:
fetch-depth: 0 # retrieve all history
diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py
new file mode 100644
index 00000000..823535a3
--- /dev/null
+++ b/notifications_utils/clients/redis/annual_limit.py
@@ -0,0 +1,188 @@
+"""
+This module stores daily notification counts and annual limit statuses for each service in Redis,
+using two hashes per service (annual-limit:{service_id}:notifications and annual-limit:{service_id}:status):
+
+annual-limit: {
+ {service_id}: {
+ notifications: {
+ sms_delivered: int,
+ email_delivered: int,
+ sms_failed: int,
+ email_failed: int
+ },
+ status: {
+ near_sms_limit: Datetime,
+ near_email_limit: Datetime,
+ over_sms_limit: Datetime,
+        over_email_limit: Datetime,
+ seeded_at: Datetime
+ }
+ }
+}
+
+
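+A minimal usage sketch (assumes a configured RedisClient; the service id is illustrative):
+
+    annual_limit_client = RedisAnnualLimit(redis_client)
+    annual_limit_client.increment_sms_delivered("6ce466d0-fd6a-11e5-82f5-e0accb9d11a6")
+    annual_limit_client.get_all_notification_counts("6ce466d0-fd6a-11e5-82f5-e0accb9d11a6")  # {"sms_delivered": 1}
+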
+"""
+
+from datetime import datetime
+
+from notifications_utils.clients.redis.redis_client import RedisClient
+
+SMS_DELIVERED = "sms_delivered"
+EMAIL_DELIVERED = "email_delivered"
+SMS_FAILED = "sms_failed"
+EMAIL_FAILED = "email_failed"
+
+NOTIFICATIONS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED]
+
+NEAR_SMS_LIMIT = "near_sms_limit"
+NEAR_EMAIL_LIMIT = "near_email_limit"
+OVER_SMS_LIMIT = "over_sms_limit"
+OVER_EMAIL_LIMIT = "over_email_limit"
+SEEDED_AT = "seeded_at"
+
+STATUSES = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT]
+
+
+def annual_limit_notifications_key(service_id):
+ """
+ Generates the Redis hash key for storing daily metrics of a service.
+ """
+ return f"annual-limit:{service_id}:notifications"
+
+
+def annual_limit_status_key(service_id):
+ """
+ Generates the Redis hash key for storing annual limit information of a service.
+ """
+ return f"annual-limit:{service_id}:status"
+
+
+def decode_byte_dict(byte_dict: dict, value_type=str):
+ """
+    Redis-py returns byte strings for keys and values. This function decodes keys to UTF-8 strings and casts values to value_type.
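+
+    Example (illustrative):
+        decode_byte_dict({b"sms_delivered": b"3"}, int)  # -> {"sms_delivered": 3}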
+ """
+    # Guard against unsupported value casts
+    if value_type not in {int, float, str}:
+        raise ValueError("value_type must be int, float, or str")
+    return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()}
+
+
+class RedisAnnualLimit:
+ def __init__(self, redis: RedisClient):
+ self._redis_client = redis
+
+ def init_app(self, app, *args, **kwargs):
+ pass
+
+ def increment_notification_count(self, service_id: str, field: str):
+ self._redis_client.increment_hash_value(annual_limit_notifications_key(service_id), field)
+
+ def get_notification_count(self, service_id: str, field: str):
+ """
+        Retrieves the specified daily notification count for a service (e.g. SMS_DELIVERED, EMAIL_FAILED).
+ """
+ return int(self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field))
+
+ def get_all_notification_counts(self, service_id: str):
+ """
+ Retrieves all daily notification metrics for a service.
+ """
+ return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int)
+
+ def reset_all_notification_counts(self, service_ids=None):
+ """
+ Resets all daily notification metrics.
+        :param service_ids: list of service_ids to reset; if None, resets all services
+ """
+ hashes = (
+ annual_limit_notifications_key("*")
+ if not service_ids
+ else [annual_limit_notifications_key(service_id) for service_id in service_ids]
+ )
+
+ self._redis_client.delete_hash_fields(hashes=hashes)
+
+ def seed_annual_limit_notifications(self, service_id: str, mapping: dict):
+ self._redis_client.bulk_set_hash_fields(key=annual_limit_notifications_key(service_id), mapping=mapping)
+
+ def was_seeded_today(self, service_id):
+ last_seeded_time = self.get_seeded_at(service_id)
+ return last_seeded_time == datetime.utcnow().strftime("%Y-%m-%d") if last_seeded_time else False
+
+ def get_seeded_at(self, service_id: str):
+ seeded_at = self._redis_client.get_hash_field(annual_limit_status_key(service_id), SEEDED_AT)
+ return seeded_at and seeded_at.decode("utf-8")
+
+ def set_seeded_at(self, service_id):
+ self._redis_client.set_hash_value(annual_limit_status_key(service_id), SEEDED_AT, datetime.utcnow().strftime("%Y-%m-%d"))
+
+ def clear_notification_counts(self, service_id: str):
+ """
+ Clears all daily notification metrics for a service.
+ """
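+        # Redis deletes the key outright when EXPIRE is given a non-positive TTL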
+ self._redis_client.expire(annual_limit_notifications_key(service_id), -1)
+
+ def set_annual_limit_status(self, service_id: str, field: str, value: datetime):
+ """
+        Sets an annual limit status field (e.g. NEAR_SMS_LIMIT, OVER_EMAIL_LIMIT) in the annual limit status Redis hash.
+ """
+ self._redis_client.set_hash_value(annual_limit_status_key(service_id), field, value.strftime("%Y-%m-%d"))
+
+ def get_annual_limit_status(self, service_id: str, field: str):
+ """
+ Retrieves the value of a specific annual limit status from the Redis hash.
+ """
+ response = self._redis_client.get_hash_field(annual_limit_status_key(service_id), field)
+ return response.decode("utf-8") if response is not None else None
+
+ def get_all_annual_limit_statuses(self, service_id: str):
+ return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id)))
+
+ def clear_annual_limit_statuses(self, service_id: str):
+        self._redis_client.expire(annual_limit_status_key(service_id), -1)
+
+ # Helper methods for daily metrics
+ def increment_sms_delivered(self, service_id: str):
+ self.increment_notification_count(service_id, SMS_DELIVERED)
+
+ def increment_sms_failed(self, service_id: str):
+ self.increment_notification_count(service_id, SMS_FAILED)
+
+ def increment_email_delivered(self, service_id: str):
+ self.increment_notification_count(service_id, EMAIL_DELIVERED)
+
+ def increment_email_failed(self, service_id: str):
+ self.increment_notification_count(service_id, EMAIL_FAILED)
+
+ # Helper methods for annual limits
+ def set_nearing_sms_limit(self, service_id: str):
+ self.set_annual_limit_status(service_id, NEAR_SMS_LIMIT, datetime.utcnow())
+
+ def set_nearing_email_limit(self, service_id: str):
+ self.set_annual_limit_status(service_id, NEAR_EMAIL_LIMIT, datetime.utcnow())
+
+ def set_over_sms_limit(self, service_id: str):
+ self.set_annual_limit_status(service_id, OVER_SMS_LIMIT, datetime.utcnow())
+
+ def set_over_email_limit(self, service_id: str):
+ self.set_annual_limit_status(service_id, OVER_EMAIL_LIMIT, datetime.utcnow())
+
+ def check_has_warning_been_sent(self, service_id: str, message_type: str):
+ """
+ Check if an annual limit warning email has been sent to the service.
+ Returns None if no warning has been sent, otherwise returns the date the
+ last warning was issued.
+        When a service's annual limit is increased, this value is reset.
+ """
+ field_to_fetch = NEAR_SMS_LIMIT if message_type == "sms" else NEAR_EMAIL_LIMIT
+ return self.get_annual_limit_status(service_id, field_to_fetch)
+
+ def check_has_over_limit_been_sent(self, service_id: str, message_type: str):
+ """
+ Check if an annual limit exceeded email has been sent to the service.
+ Returns None if no exceeded email has been sent, otherwise returns the date the
+ last exceeded email was issued.
+        When a service's annual limit is increased, this value is reset.
+ """
+ field_to_fetch = OVER_SMS_LIMIT if message_type == "sms" else OVER_EMAIL_LIMIT
+ return self.get_annual_limit_status(service_id, field_to_fetch)
diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py
index c85b2678..3f4831ab 100644
--- a/notifications_utils/clients/redis/redis_client.py
+++ b/notifications_utils/clients/redis/redis_client.py
@@ -1,7 +1,7 @@
import numbers
import uuid
from time import time
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from flask import current_app
from flask_redis import FlaskRedis
@@ -81,6 +81,64 @@ def delete_cache_keys_by_pattern(self, pattern):
return self.scripts["delete-keys-by-pattern"](args=[pattern])
return 0
+ # TODO: Refactor and simplify this to use HEXPIRE when we upgrade Redis to 7.4.0
+    def delete_hash_fields(self, hashes: str | list, fields: Optional[list] = None, raise_exception=False):
+        """Deletes fields from the specified hashes. If `fields` is None, all fields are deleted, removing each hash entirely.
+
+ Args:
+ hashes (str|list): The hash pattern or list of hash keys to delete fields from.
+ fields (list): A list of fields to delete from the hashes. If `None`, then all fields are deleted.
+
+        Returns:
+            list: the pipeline results (a deleted-field count per hash); False if an error was suppressed, or None if Redis is inactive.
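+
+        Example (illustrative keys):
+            delete_hash_fields("annual-limit:*:notifications")  # delete every notifications hash
+            delete_hash_fields("annual-limit:1234:notifications", fields=["sms_delivered", "sms_failed"])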
+ """
+ if self.active:
+ try:
+ hashes = [prepare_value(h) for h in hashes] if isinstance(hashes, list) else prepare_value(hashes)
+            # When fields are passed in, use the list as-is.
+            # When hashes is a list and no fields are passed in, fetch the fields from the first hash in the list.
+            # Otherwise we will scan-iterate over a pattern, so fields are fetched on the first pass of the loop below.
+ fields = (
+ [prepare_value(f) for f in fields]
+ if fields is not None
+ else self.redis_store.hkeys(hashes[0])
+ if isinstance(hashes, list)
+ else None
+ )
+ # Use a pipeline to atomically delete fields from each hash.
+ pipe = self.redis_store.pipeline()
+ # if hashes is not a list, we're scan iterating over keys matching a pattern
+ for key in hashes if isinstance(hashes, list) else self.redis_store.scan_iter(hashes):
+ if not fields:
+ fields = self.redis_store.hkeys(key)
+ key = prepare_value(key)
+ pipe.hdel(key, *fields)
+ result = pipe.execute()
+ # TODO: May need to double check that the pipeline result count matches the number of hashes deleted
+ # and retry any failures
+ return result
+ except Exception as e:
+                self.__handle_exception(e, raise_exception, "delete_hash_fields", hashes)
+ return False
+
+ def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception=False):
+ """
+ Bulk set hash fields.
+        :param mapping: the mapping of fields to set
+        :param pattern: a key pattern; when given, the fields are set on every matching hash
+        :param key: a single hash key to set the fields on (used when no pattern is given)
+        :param raise_exception: True if we should allow the exception to bubble up
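+
+        Example (illustrative): bulk_set_hash_fields(mapping={"sms_delivered": 0}, pattern="annual-limit:*:notifications")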
+ """
+ if self.active:
+ try:
+ if pattern:
+ for key in self.redis_store.scan_iter(pattern):
+ self.redis_store.hmset(key, mapping)
+ if key:
+ return self.redis_store.hmset(key, mapping)
+ except Exception as e:
+ self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", pattern)
+ return False
+
def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False):
"""
Rate limiting.
@@ -228,17 +286,44 @@ def get(self, key, raise_exception=False):
return None
+ def set_hash_value(self, key, field, value, raise_exception=False):
+ key = prepare_value(key)
+ field = prepare_value(field)
+ value = prepare_value(value)
+
+ if self.active:
+ try:
+ return self.redis_store.hset(key, field, value)
+ except Exception as e:
+ self.__handle_exception(e, raise_exception, "set_hash_value", key)
+
+ return None
+
def decrement_hash_value(self, key, value, raise_exception=False):
return self.increment_hash_value(key, value, raise_exception, incr_by=-1)
def increment_hash_value(self, key, value, raise_exception=False, incr_by=1):
key = prepare_value(key)
value = prepare_value(value)
+
if self.active:
try:
return self.redis_store.hincrby(key, value, incr_by)
except Exception as e:
self.__handle_exception(e, raise_exception, "increment_hash_value", key)
+ return None
+
+ def get_hash_field(self, key, field, raise_exception=False):
+ key = prepare_value(key)
+ field = prepare_value(field)
+
+ if self.active:
+ try:
+ return self.redis_store.hget(key, field)
+ except Exception as e:
+ self.__handle_exception(e, raise_exception, "get_hash_field", key)
+
+ return None
def get_all_from_hash(self, key, raise_exception=False):
key = prepare_value(key)
@@ -248,6 +333,8 @@ def get_all_from_hash(self, key, raise_exception=False):
except Exception as e:
self.__handle_exception(e, raise_exception, "get_all_from_hash", key)
+ return None
+
def set_hash_and_expire(self, key, values, expire_in_seconds, raise_exception=False):
key = prepare_value(key)
values = {prepare_value(k): prepare_value(v) for k, v in values.items()}
@@ -258,6 +345,8 @@ def set_hash_and_expire(self, key, values, expire_in_seconds, raise_exception=Fa
except Exception as e:
self.__handle_exception(e, raise_exception, "set_hash_and_expire", key)
+ return None
+
def expire(self, key, expire_in_seconds, raise_exception=False):
key = prepare_value(key)
if self.active:
diff --git a/notifications_utils/formatters.py b/notifications_utils/formatters.py
index 79ee215f..c6c48d67 100644
--- a/notifications_utils/formatters.py
+++ b/notifications_utils/formatters.py
@@ -417,7 +417,7 @@ def list(self, body, ordered=True):
'
'
""
''
- ''
+ ''
"{}"
" "
" | "
@@ -429,7 +429,7 @@ def list(self, body, ordered=True):
''
""
''
- ' | "
@@ -440,8 +440,8 @@ def list(self, body, ordered=True):
def list_item(self, text):
return (
- ''
+ ''
"{}"
""
).format(text.strip())
diff --git a/notifications_utils/jinja_templates/email/email_preview_template.jinja2 b/notifications_utils/jinja_templates/email/email_preview_template.jinja2
index 3224d64a..9a7d60af 100644
--- a/notifications_utils/jinja_templates/email/email_preview_template.jinja2
+++ b/notifications_utils/jinja_templates/email/email_preview_template.jinja2
@@ -75,7 +75,7 @@
{% endif %}
-
+
{{ body }}
diff --git a/notifications_utils/jinja_templates/email/email_template.jinja2 b/notifications_utils/jinja_templates/email/email_template.jinja2
index 5b69eb67..70773033 100644
--- a/notifications_utils/jinja_templates/email/email_template.jinja2
+++ b/notifications_utils/jinja_templates/email/email_template.jinja2
@@ -51,8 +51,8 @@
style="border-collapse: collapse; width:100% !important; max-width: 580px; margin: 0 auto;"
>
-
- {{ body|safe }}
+ |
+ {{ body|safe }}
|
diff --git a/notifications_utils/template.py b/notifications_utils/template.py
index 5bccecf4..c6e00d77 100644
--- a/notifications_utils/template.py
+++ b/notifications_utils/template.py
@@ -375,6 +375,8 @@ def __init__(
self.allow_html = allow_html
self.alt_text_en = alt_text_en
self.alt_text_fr = alt_text_fr
+ self.text_direction_rtl = template.get("text_direction_rtl", False)
+
# set this again to make sure the correct either utils / downstream local jinja is used
# however, don't set if we are in a test environment (to preserve the above mock)
if "pytest" not in sys.modules:
@@ -416,6 +418,7 @@ def __str__(self):
"brand_name": self.brand_name,
"alt_text_en": self.alt_text_en,
"alt_text_fr": self.alt_text_fr,
+ "text_direction_rtl": self.text_direction_rtl,
}
)
@@ -483,6 +486,7 @@ def __init__(
self.allow_html = allow_html
self.alt_text_en = alt_text_en
self.alt_text_fr = alt_text_fr
+ self.text_direction_rtl = template.get("text_direction_rtl", False)
def __str__(self):
return Markup(
@@ -509,6 +513,7 @@ def __str__(self):
"asset_domain": self.asset_domain,
"alt_text_en": self.alt_text_en,
"alt_text_fr": self.alt_text_fr,
+ "text_direction_rtl": self.text_direction_rtl,
}
)
)
diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py
new file mode 100644
index 00000000..04adc1b9
--- /dev/null
+++ b/tests/test_annual_limit.py
@@ -0,0 +1,298 @@
+import uuid
+from datetime import datetime
+from unittest.mock import Mock
+
+import fakeredis
+import pytest
+from freezegun import freeze_time
+from notifications_utils.clients.redis.annual_limit import (
+ EMAIL_DELIVERED,
+ EMAIL_FAILED,
+ NEAR_EMAIL_LIMIT,
+ NEAR_SMS_LIMIT,
+ OVER_EMAIL_LIMIT,
+ OVER_SMS_LIMIT,
+ SMS_DELIVERED,
+ SMS_FAILED,
+ RedisAnnualLimit,
+ annual_limit_notifications_key,
+ annual_limit_status_key,
+)
+from notifications_utils.clients.redis.redis_client import RedisClient
+
+
+@pytest.fixture(scope="function")
+def mock_notification_count_types():
+ return [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED]
+
+
+@pytest.fixture(scope="function")
+def mock_annual_limit_statuses():
+ return [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT]
+
+
+@pytest.fixture(scope="function")
+def mocked_redis_pipeline():
+ return Mock()
+
+
+@pytest.fixture
+def mock_redis_client(app, mocked_redis_pipeline, mocker):
+ app.config["REDIS_ENABLED"] = True
+ return build_redis_client(app, mocked_redis_pipeline, mocker)
+
+
+@pytest.fixture(scope="function")
+def better_mocked_redis_client(app):
+ app.config["REDIS_ENABLED"] = True
+ redis_client = RedisClient()
+ redis_client.redis_store = fakeredis.FakeStrictRedis(version=6) # type: ignore
+ redis_client.active = True
+ return redis_client
+
+
+@pytest.fixture
+def redis_annual_limit(mock_redis_client):
+ return RedisAnnualLimit(mock_redis_client)
+
+
+def build_redis_client(app, mocked_redis_pipeline, mocker):
+ redis_client = RedisClient()
+ redis_client.init_app(app)
+ return redis_client
+
+
+def build_annual_limit_client(mocker, better_mocked_redis_client):
+ annual_limit_client = RedisAnnualLimit(better_mocked_redis_client)
+ return annual_limit_client
+
+
+@pytest.fixture(scope="function")
+def mock_annual_limit_client(better_mocked_redis_client, mocker):
+ return RedisAnnualLimit(better_mocked_redis_client)
+
+
+@pytest.fixture(scope="function")
+def mocked_service_id():
+ return str(uuid.uuid4())
+
+
+def test_notifications_key(mocked_service_id):
+ expected_key = f"annual-limit:{mocked_service_id}:notifications"
+ assert annual_limit_notifications_key(mocked_service_id) == expected_key
+
+
+def test_annual_limits_key(mocked_service_id):
+ expected_key = f"annual-limit:{mocked_service_id}:status"
+ assert annual_limit_status_key(mocked_service_id) == expected_key
+
+
+@pytest.mark.parametrize(
+ "increment_by, metric",
+ [
+ (1, SMS_DELIVERED),
+ (1, EMAIL_DELIVERED),
+ (1, SMS_FAILED),
+ (1, EMAIL_FAILED),
+ (2, SMS_DELIVERED),
+ (2, EMAIL_DELIVERED),
+ (2, SMS_FAILED),
+ (2, EMAIL_FAILED),
+ ],
+)
+def test_increment_notification_count(mock_annual_limit_client, mocked_service_id, metric, increment_by):
+ for _ in range(increment_by):
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, metric)
+ counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id)
+ assert int(counts[metric]) == increment_by
+
+
+def test_get_notification_count(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, SMS_DELIVERED)
+ result = mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED)
+ assert result == 1
+
+
+def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
+ assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4
+
+
+def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
+ assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4
+ mock_annual_limit_client.clear_notification_counts(mocked_service_id)
+ assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0
+
+
+@pytest.mark.parametrize(
+ "service_ids",
+ [
+ [
+ str(uuid.uuid4()),
+ str(uuid.uuid4()),
+ str(uuid.uuid4()),
+ str(uuid.uuid4()),
+ ]
+ ],
+)
+def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notification_count_types, service_ids):
+ for service_id in service_ids:
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(service_id, field)
+ assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4
+
+ mock_annual_limit_client.reset_all_notification_counts()
+
+ for service_id in service_ids:
+ assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 0
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow())
+ result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT)
+ assert result == datetime.utcnow().strftime("%Y-%m-%d")
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_get_annual_limit_status(mock_annual_limit_client, mocked_service_id):
+ near_limit_date = datetime.utcnow()
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, near_limit_date)
+ result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT)
+ assert result == near_limit_date.strftime("%Y-%m-%d")
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
+ for status in mock_annual_limit_statuses:
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
+ assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4
+ mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id)
+ assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+@pytest.mark.parametrize("seeded_at_value, expected_value", [(b"2024-10-25", True), (None, False)])
+def test_was_seeded_today(mock_annual_limit_client, seeded_at_value, expected_value, mocked_service_id, mocker):
+ mocker.patch.object(mock_annual_limit_client._redis_client, "get_hash_field", return_value=seeded_at_value)
+ result = mock_annual_limit_client.was_seeded_today(mocked_service_id)
+ assert result == expected_value
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_set_seeded_at(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_seeded_at(mocked_service_id)
+ result = mock_annual_limit_client.get_seeded_at(mocked_service_id)
+ assert result == datetime.utcnow().strftime("%Y-%m-%d")
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+@pytest.mark.parametrize("seeded_at_value, expected_value", [(b"2024-10-25", "2024-10-25"), (None, None)])
+def test_get_seeded_at(mock_annual_limit_client, seeded_at_value, expected_value, mocked_service_id, mocker):
+ mocker.patch.object(mock_annual_limit_client._redis_client, "get_hash_field", return_value=seeded_at_value)
+ result = mock_annual_limit_client.get_seeded_at(mocked_service_id)
+ assert result == expected_value
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id)
+ result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT)
+ assert result == datetime.utcnow().strftime("%Y-%m-%d")
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_set_over_sms_limit(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_over_sms_limit(mocked_service_id)
+ result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, OVER_SMS_LIMIT)
+ assert result == datetime.utcnow().strftime("%Y-%m-%d")
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_set_nearing_email_limit(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_nearing_email_limit(mocked_service_id)
+ result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_EMAIL_LIMIT)
+ assert result == datetime.utcnow().strftime("%Y-%m-%d")
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_set_over_email_limit(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_over_email_limit(mocked_service_id)
+ result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, OVER_EMAIL_LIMIT)
+ assert result == datetime.utcnow().strftime("%Y-%m-%d")
+
+
+def test_increment_sms_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
+
+ mock_annual_limit_client.increment_sms_delivered(mocked_service_id)
+
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 2
+ for field in mock_notification_count_types:
+ if field != SMS_DELIVERED:
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1
+
+
+def test_increment_sms_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
+
+ mock_annual_limit_client.increment_sms_failed(mocked_service_id)
+
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_FAILED) == 2
+ for field in mock_notification_count_types:
+ if field != SMS_FAILED:
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1
+
+
+def test_increment_email_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
+
+ mock_annual_limit_client.increment_email_delivered(mocked_service_id)
+
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_DELIVERED) == 2
+ for field in mock_notification_count_types:
+ if field != EMAIL_DELIVERED:
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1
+
+
+def test_increment_email_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
+ for field in mock_notification_count_types:
+ mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
+
+ mock_annual_limit_client.increment_email_failed(mocked_service_id)
+
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_FAILED) == 2
+ for field in mock_notification_count_types:
+ if field != EMAIL_FAILED:
+ assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_check_has_warning_been_sent(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow())
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_EMAIL_LIMIT, datetime.utcnow())
+
+ assert mock_annual_limit_client.check_has_warning_been_sent(mocked_service_id, "sms") == datetime.utcnow().strftime(
+ "%Y-%m-%d"
+ )
+ assert mock_annual_limit_client.check_has_warning_been_sent(mocked_service_id, "email") == datetime.utcnow().strftime(
+ "%Y-%m-%d"
+ )
+
+
+@freeze_time("2024-10-25 12:00:00.000000")
+def test_check_has_over_limit_been_sent(mock_annual_limit_client, mocked_service_id):
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, OVER_SMS_LIMIT, datetime.utcnow())
+ mock_annual_limit_client.set_annual_limit_status(mocked_service_id, OVER_EMAIL_LIMIT, datetime.utcnow())
+
+ assert mock_annual_limit_client.check_has_over_limit_been_sent(mocked_service_id, "sms") == datetime.utcnow().strftime(
+ "%Y-%m-%d"
+ )
+ assert mock_annual_limit_client.check_has_over_limit_been_sent(mocked_service_id, "email") == datetime.utcnow().strftime(
+ "%Y-%m-%d"
+ )
diff --git a/tests/test_formatters.py b/tests/test_formatters.py
index 648bd6ff..0e6f0ab2 100644
--- a/tests/test_formatters.py
+++ b/tests/test_formatters.py
@@ -351,13 +351,10 @@ def test_hrule(markdown_function, expected):
''
""
''
- ''
- '- one
'
- '- two
'
- '- three
'
+ ''
+ '- one
'
+ '- two
'
+ '- three
'
" "
" | "
"
"
@@ -414,13 +411,10 @@ def test_ordered_list(markdown_function, markdown_input, expected):
''
""
''
- ''
- '- one
'
- '- two
'
- '- three
'
+ ''
+ '- one
'
+ '- two
'
+ '- three
'
" "
" | "
"
"
@@ -506,8 +500,8 @@ def test_pluses_dont_render_as_lists(markdown_function, expected):
"* **title**: description",
''
''
- ''
- '- '
+ '
'
+ '- '
"title: description
|
",
],
),
@@ -1065,7 +1059,7 @@ class TestAddLanguageDivs:
1. item 2
1. item 3
[[/fr]]""",
- f'{EMAIL_P_OPEN_TAG}Le français suis l\'anglais{EMAIL_P_CLOSE_TAG}
{EMAIL_P_OPEN_TAG}bonjour{EMAIL_P_CLOSE_TAG}
', # noqa
+ f'{EMAIL_P_OPEN_TAG}Le français suis l\'anglais{EMAIL_P_CLOSE_TAG}
{EMAIL_P_OPEN_TAG}bonjour{EMAIL_P_CLOSE_TAG}
', # noqa
),
("[[en]]No closing tag", f"{EMAIL_P_OPEN_TAG}[[en]]No closing tag{EMAIL_P_CLOSE_TAG}"),
("No opening tag[[/en]]", f"{EMAIL_P_OPEN_TAG}No opening tag[[/en]]{EMAIL_P_CLOSE_TAG}"),
diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py
index e54457b9..a334cedb 100644
--- a/tests/test_redis_client.py
+++ b/tests/test_redis_client.py
@@ -33,6 +33,27 @@ def better_mocked_redis_client(app):
return redis_client
+@pytest.fixture(scope="function")
+def mocked_hash_structure():
+ return {
+ "key1": {
+ "field1": "value1",
+ "field2": 2,
+ "field3": "value3".encode("utf-8"),
+ },
+ "key2": {
+ "field1": "value1",
+ "field2": 2,
+ "field3": "value3".encode("utf-8"),
+ },
+ "key3": {
+ "field1": "value1",
+ "field2": 2,
+ "field3": "value3".encode("utf-8"),
+ },
+ }
+
+
def build_redis_client(app, mocked_redis_pipeline, mocker):
redis_client = RedisClient()
redis_client.init_app(app)
@@ -40,6 +61,7 @@ def build_redis_client(app, mocked_redis_pipeline, mocker):
mocker.patch.object(redis_client.redis_store, "set")
mocker.patch.object(redis_client.redis_store, "hincrby")
mocker.patch.object(redis_client.redis_store, "hgetall", return_value={b"template-1111": b"8", b"template-2222": b"8"})
+ mocker.patch.object(redis_client.redis_store, "hset")
mocker.patch.object(redis_client.redis_store, "hmset")
mocker.patch.object(redis_client.redis_store, "expire")
mocker.patch.object(redis_client.redis_store, "delete")
@@ -167,44 +189,6 @@ def test_should_build_rate_limit_cache_key(sample_service):
assert rate_limit_cache_key(sample_service.id, "TEST") == "{}-TEST".format(sample_service.id)
-def test_decrement_hash_value_should_decrement_value_by_one_for_key(mocked_redis_client):
- key = "12345"
- value = "template-1111"
-
- mocked_redis_client.decrement_hash_value(key, value, -1)
- mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, -1)
-
-
-def test_incr_hash_value_should_increment_value_by_one_for_key(mocked_redis_client):
- key = "12345"
- value = "template-1111"
-
- mocked_redis_client.increment_hash_value(key, value)
- mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, 1)
-
-
-def test_get_all_from_hash_returns_hash_for_key(mocked_redis_client):
- key = "12345"
- assert mocked_redis_client.get_all_from_hash(key) == {b"template-1111": b"8", b"template-2222": b"8"}
- mocked_redis_client.redis_store.hgetall.assert_called_with(key)
-
-
-def test_set_hash_and_expire(mocked_redis_client):
- key = "hash-key"
- values = {"key": 10}
- mocked_redis_client.set_hash_and_expire(key, values, 1)
- mocked_redis_client.redis_store.hmset.assert_called_with(key, values)
- mocked_redis_client.redis_store.expire.assert_called_with(key, 1)
-
-
-def test_set_hash_and_expire_converts_values_to_valid_types(mocked_redis_client):
- key = "hash-key"
- values = {uuid.UUID(int=0): 10}
- mocked_redis_client.set_hash_and_expire(key, values, 1)
- mocked_redis_client.redis_store.hmset.assert_called_with(key, {"00000000-0000-0000-0000-000000000000": 10})
- mocked_redis_client.redis_store.expire.assert_called_with(key, 1)
-
-
@freeze_time("2001-01-01 12:00:00.000000")
def test_should_add_correct_calls_to_the_pipe(mocked_redis_client, mocked_redis_pipeline):
mocked_redis_client.exceeded_rate_limit("key", 100, 100)
@@ -315,3 +299,131 @@ def test_get_length_of_sorted_set_returns_none_if_not_active(self, better_mocked
better_mocked_redis_client.active = False
ret = better_mocked_redis_client.get_length_of_sorted_set("cache_key", min_score=0, max_score=100)
assert ret == 0
+
+
+class TestRedisHashes:
+ @pytest.mark.parametrize(
+ "hash_key, fields_to_delete, expected_deleted, check_if_no_longer_exists",
+ [
+            ("test:hash:key1", ["field1", "field2"], 2, False),  # Delete specific fields in a hash
+            ("test:hash:key1", None, 3, True),  # Delete all fields in a specific hash
+            ("test:hash:*", ["field1", "field2"], 6, False),  # Delete specific fields in a group of hashes
+            ("test:hash:*", None, 9, True),  # Delete all fields in a group of hashes
+ ],
+ )
+ def test_delete_hash_fields(
+ self,
+ better_mocked_redis_client,
+ hash_key,
+ fields_to_delete,
+ expected_deleted,
+ check_if_no_longer_exists,
+ mocked_hash_structure,
+ ):
+ # set up the hashes to be deleted
+ for key, fields in mocked_hash_structure.items():
+ better_mocked_redis_client.bulk_set_hash_fields(key=f"test:hash:{key}", mapping=fields)
+
+ num_deleted = better_mocked_redis_client.delete_hash_fields(hashes=hash_key, fields=fields_to_delete)
+
+ # Deleting all hash fields by pattern
+ if check_if_no_longer_exists and "*" in hash_key:
+ for key in mocked_hash_structure.keys():
+ assert better_mocked_redis_client.redis_store.exists(f"test:hash:{key}") == 0
+ # Deleting a specific hash
+ elif check_if_no_longer_exists:
+            assert better_mocked_redis_client.redis_store.exists(hash_key) == 0
+
+ # Make sure we've deleted the correct number of fields
+ assert sum(num_deleted) == expected_deleted
+
+ def test_get_hash_field(self, mocked_redis_client):
+ key = "12345"
+ field = "template-1111"
+ mocked_redis_client.redis_store.hget = Mock(return_value=b"8")
+ assert mocked_redis_client.get_hash_field(key, field) == b"8"
+ mocked_redis_client.redis_store.hget.assert_called_with(key, field)
+
+ def test_set_hash_value(self, mocked_redis_client):
+ key = "12345"
+ field = "template-1111"
+ value = 8
+ mocked_redis_client.set_hash_value(key, field, value)
+ mocked_redis_client.redis_store.hset.assert_called_with(key, field, value)
+
+ @pytest.mark.parametrize(
+ "hash, updates, expected",
+ [
+ (
+ {
+ "key1": {
+ "field1": "value1",
+ "field2": 2,
+ "field3": "value3".encode("utf-8"),
+ },
+ "key2": {
+ "field1": "value1",
+ "field2": 2,
+ "field3": "value3".encode("utf-8"),
+ },
+ "key3": {
+ "field1": "value1",
+ "field2": 2,
+ "field3": "value3".encode("utf-8"),
+ },
+ },
+ {
+ "field1": "value2",
+ "field2": 3,
+ "field3": "value4".encode("utf-8"),
+ },
+ {
+ b"field1": b"value2",
+ b"field2": b"3",
+ b"field3": b"value4",
+ },
+ )
+ ],
+ )
+ def test_bulk_set_hash_fields(self, better_mocked_redis_client, hash, updates, expected):
+ for key, fields in hash.items():
+ for field, value in fields.items():
+ better_mocked_redis_client.set_hash_value(key, field, value)
+
+ better_mocked_redis_client.bulk_set_hash_fields(pattern="key*", mapping=updates)
+
+ for key, _ in hash.items():
+ assert better_mocked_redis_client.redis_store.hgetall(key) == expected
+
+ def test_decrement_hash_value_should_decrement_value_by_one_for_key(self, mocked_redis_client):
+ key = "12345"
+ value = "template-1111"
+
+ mocked_redis_client.decrement_hash_value(key, value, -1)
+ mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, -1)
+
+ def test_incr_hash_value_should_increment_value_by_one_for_key(self, mocked_redis_client):
+ key = "12345"
+ value = "template-1111"
+
+ mocked_redis_client.increment_hash_value(key, value)
+ mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, 1)
+
+ def test_get_all_from_hash_returns_hash_for_key(self, mocked_redis_client):
+ key = "12345"
+ assert mocked_redis_client.get_all_from_hash(key) == {b"template-1111": b"8", b"template-2222": b"8"}
+ mocked_redis_client.redis_store.hgetall.assert_called_with(key)
+
+ def test_set_hash_and_expire(self, mocked_redis_client):
+ key = "hash-key"
+ values = {"key": 10}
+ mocked_redis_client.set_hash_and_expire(key, values, 1)
+ mocked_redis_client.redis_store.hmset.assert_called_with(key, values)
+ mocked_redis_client.redis_store.expire.assert_called_with(key, 1)
+
+ def test_set_hash_and_expire_converts_values_to_valid_types(self, mocked_redis_client):
+ key = "hash-key"
+ values = {uuid.UUID(int=0): 10}
+ mocked_redis_client.set_hash_and_expire(key, values, 1)
+ mocked_redis_client.redis_store.hmset.assert_called_with(key, {"00000000-0000-0000-0000-000000000000": 10})
+ mocked_redis_client.redis_store.expire.assert_called_with(key, 1)
diff --git a/tests/test_template_types.py b/tests/test_template_types.py
index abf9acc8..6d220b0f 100644
--- a/tests/test_template_types.py
+++ b/tests/test_template_types.py
@@ -2255,3 +2255,34 @@ def test_image_class_applied_to_logo(template_class, filename, expected_html_cla
def test_image_not_present_if_no_logo(template_class):
# can't test that the html doesn't move in utils - tested in template preview instead
assert "