diff --git a/.github/workflows/export_github_data.yml b/.github/workflows/export_github_data.yml index db925884..2cd11121 100644 --- a/.github/workflows/export_github_data.yml +++ b/.github/workflows/export_github_data.yml @@ -14,7 +14,7 @@ jobs: DNS_PROXY_FORWARDTOSENTINEL: "true" DNS_PROXY_LOGANALYTICSWORKSPACEID: ${{ secrets.LOG_ANALYTICS_WORKSPACE_ID }} DNS_PROXY_LOGANALYTICSSHAREDKEY: ${{ secrets.LOG_ANALYTICS_WORKSPACE_KEY }} - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Export Data uses: cds-snc/github-repository-metadata-exporter@main with: diff --git a/.github/workflows/ossf-scorecard.yml b/.github/workflows/ossf-scorecard.yml index 87010d01..556ccb91 100644 --- a/.github/workflows/ossf-scorecard.yml +++ b/.github/workflows/ossf-scorecard.yml @@ -20,7 +20,7 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: persist-credentials: false diff --git a/.github/workflows/s3-backup.yml b/.github/workflows/s3-backup.yml index 6a8e9670..d850d5e0 100644 --- a/.github/workflows/s3-backup.yml +++ b/.github/workflows/s3-backup.yml @@ -10,7 +10,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: fetch-depth: 0 # retrieve all history diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py new file mode 100644 index 00000000..823535a3 --- /dev/null +++ b/notifications_utils/clients/redis/annual_limit.py @@ -0,0 +1,188 @@ +""" +This module stores daily notification counts and annual limit statuses for a service in Redis using a hash structure: + + +annual-limit: { + {service_id}: { + notifications: { + sms_delivered: int, + email_delivered: int, + sms_failed: int, + email_failed: int + }, + status: { + near_sms_limit: Datetime, + near_email_limit: Datetime, + over_sms_limit: Datetime, + over_email_limit: Datetime + seeded_at: Datetime + } + } +} + + +""" + +from datetime import datetime + +from notifications_utils.clients.redis.redis_client import RedisClient + +SMS_DELIVERED = "sms_delivered" +EMAIL_DELIVERED = "email_delivered" +SMS_FAILED = "sms_failed" +EMAIL_FAILED = "email_failed" + +NOTIFICATIONS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] + +NEAR_SMS_LIMIT = "near_sms_limit" +NEAR_EMAIL_LIMIT = "near_email_limit" +OVER_SMS_LIMIT = "over_sms_limit" +OVER_EMAIL_LIMIT = "over_email_limit" +SEEDED_AT = "seeded_at" + +STATUSES = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] + + +def annual_limit_notifications_key(service_id): + """ + Generates the Redis hash key for storing daily metrics of a service. + """ + return f"annual-limit:{service_id}:notifications" + + +def annual_limit_status_key(service_id): + """ + Generates the Redis hash key for storing annual limit information of a service. + """ + return f"annual-limit:{service_id}:status" + + +def decode_byte_dict(dict: dict, value_type=str): + """ + Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings. + """ + # Check if expected_value_type is one of the allowed types + if value_type not in {int, float, str}: + raise ValueError("expected_value_type must be int, float, or str") + return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items() if dict.items()} + + +class RedisAnnualLimit: + def __init__(self, redis: RedisClient): + self._redis_client = redis + + def init_app(self, app, *args, **kwargs): + pass + + def increment_notification_count(self, service_id: str, field: str): + self._redis_client.increment_hash_value(annual_limit_notifications_key(service_id), field) + + def get_notification_count(self, service_id: str, field: str): + """ + Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.) + """ + return int(self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field)) + + def get_all_notification_counts(self, service_id: str): + """ + Retrieves all daily notification metrics for a service. + """ + return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int) + + def reset_all_notification_counts(self, service_ids=None): + """ + Resets all daily notification metrics. + :param: service_ids: list of service_ids to reset, if None, resets all services + """ + hashes = ( + annual_limit_notifications_key("*") + if not service_ids + else [annual_limit_notifications_key(service_id) for service_id in service_ids] + ) + + self._redis_client.delete_hash_fields(hashes=hashes) + + def seed_annual_limit_notifications(self, service_id: str, mapping: dict): + self._redis_client.bulk_set_hash_fields(key=annual_limit_notifications_key(service_id), mapping=mapping) + + def was_seeded_today(self, service_id): + last_seeded_time = self.get_seeded_at(service_id) + return last_seeded_time == datetime.utcnow().strftime("%Y-%m-%d") if last_seeded_time else False + + def get_seeded_at(self, service_id: str): + seeded_at = self._redis_client.get_hash_field(annual_limit_status_key(service_id), SEEDED_AT) + return seeded_at and seeded_at.decode("utf-8") + + def set_seeded_at(self, service_id): + self._redis_client.set_hash_value(annual_limit_status_key(service_id), SEEDED_AT, datetime.utcnow().strftime("%Y-%m-%d")) + + def clear_notification_counts(self, service_id: str): + """ + Clears all daily notification metrics for a service. + """ + self._redis_client.expire(annual_limit_notifications_key(service_id), -1) + + def set_annual_limit_status(self, service_id: str, field: str, value: datetime): + """ + Sets the status (e.g., 'nearing_limit', 'over_limit') in the annual limits Redis hash. + """ + self._redis_client.set_hash_value(annual_limit_status_key(service_id), field, value.strftime("%Y-%m-%d")) + + def get_annual_limit_status(self, service_id: str, field: str): + """ + Retrieves the value of a specific annual limit status from the Redis hash. + """ + response = self._redis_client.get_hash_field(annual_limit_status_key(service_id), field) + return response.decode("utf-8") if response is not None else None + + def get_all_annual_limit_statuses(self, service_id: str): + return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id))) + + def clear_annual_limit_statuses(self, service_id: str): + self._redis_client.expire(f"{annual_limit_status_key(service_id)}", -1) + + # Helper methods for daily metrics + def increment_sms_delivered(self, service_id: str): + self.increment_notification_count(service_id, SMS_DELIVERED) + + def increment_sms_failed(self, service_id: str): + self.increment_notification_count(service_id, SMS_FAILED) + + def increment_email_delivered(self, service_id: str): + self.increment_notification_count(service_id, EMAIL_DELIVERED) + + def increment_email_failed(self, service_id: str): + self.increment_notification_count(service_id, EMAIL_FAILED) + + # Helper methods for annual limits + def set_nearing_sms_limit(self, service_id: str): + self.set_annual_limit_status(service_id, NEAR_SMS_LIMIT, datetime.utcnow()) + + def set_nearing_email_limit(self, service_id: str): + self.set_annual_limit_status(service_id, NEAR_EMAIL_LIMIT, datetime.utcnow()) + + def set_over_sms_limit(self, service_id: str): + self.set_annual_limit_status(service_id, OVER_SMS_LIMIT, datetime.utcnow()) + + def set_over_email_limit(self, service_id: str): + self.set_annual_limit_status(service_id, OVER_EMAIL_LIMIT, datetime.utcnow()) + + def check_has_warning_been_sent(self, service_id: str, message_type: str): + """ + Check if an annual limit warning email has been sent to the service. + Returns None if no warning has been sent, otherwise returns the date the + last warning was issued. + When a service's annual limit is increased this value is reset. + """ + field_to_fetch = NEAR_SMS_LIMIT if message_type == "sms" else NEAR_EMAIL_LIMIT + return self.get_annual_limit_status(service_id, field_to_fetch) + + def check_has_over_limit_been_sent(self, service_id: str, message_type: str): + """ + Check if an annual limit exceeded email has been sent to the service. + Returns None if no exceeded email has been sent, otherwise returns the date the + last exceeded email was issued. + When a service's annual limit is increased this value is reset. + """ + field_to_fetch = OVER_SMS_LIMIT if message_type == "sms" else OVER_EMAIL_LIMIT + return self.get_annual_limit_status(service_id, field_to_fetch) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index c85b2678..3f4831ab 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -1,7 +1,7 @@ import numbers import uuid from time import time -from typing import Any, Dict +from typing import Any, Dict, Optional from flask import current_app from flask_redis import FlaskRedis @@ -81,6 +81,64 @@ def delete_cache_keys_by_pattern(self, pattern): return self.scripts["delete-keys-by-pattern"](args=[pattern]) return 0 + # TODO: Refactor and simplify this to use HEXPIRE when we upgrade Redis to 7.4.0 + def delete_hash_fields(self, hashes: (str | list), fields: Optional[list] = None, raise_exception=False): + """Deletes fields from the specified hashes. if fields is `None`, then all fields from the hashes are deleted, deleting the hash entirely. + + Args: + hashes (str|list): The hash pattern or list of hash keys to delete fields from. + fields (list): A list of fields to delete from the hashes. If `None`, then all fields are deleted. + + Returns: + _type_: _description_ + """ + if self.active: + try: + hashes = [prepare_value(h) for h in hashes] if isinstance(hashes, list) else prepare_value(hashes) + # When fields are passed in, use the list as is + # When hashes is a list, and no fields are passed in, fetch the fields from the first hash in the list + # otherwise we know we're going scan iterate over a pattern so we'll fetch the fields on the first pass in the loop below + fields = ( + [prepare_value(f) for f in fields] + if fields is not None + else self.redis_store.hkeys(hashes[0]) + if isinstance(hashes, list) + else None + ) + # Use a pipeline to atomically delete fields from each hash. + pipe = self.redis_store.pipeline() + # if hashes is not a list, we're scan iterating over keys matching a pattern + for key in hashes if isinstance(hashes, list) else self.redis_store.scan_iter(hashes): + if not fields: + fields = self.redis_store.hkeys(key) + key = prepare_value(key) + pipe.hdel(key, *fields) + result = pipe.execute() + # TODO: May need to double check that the pipeline result count matches the number of hashes deleted + # and retry any failures + return result + except Exception as e: + self.__handle_exception(e, raise_exception, "expire_hash_fields", hashes) + return False + + def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception=False): + """ + Bulk set hash fields. + :param pattern: the pattern to match keys + :param mappting: the mapping of fields to set + :param raise_exception: True if we should allow the exception to bubble up + """ + if self.active: + try: + if pattern: + for key in self.redis_store.scan_iter(pattern): + self.redis_store.hmset(key, mapping) + if key: + return self.redis_store.hmset(key, mapping) + except Exception as e: + self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", pattern) + return False + def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False): """ Rate limiting. @@ -228,17 +286,44 @@ def get(self, key, raise_exception=False): return None + def set_hash_value(self, key, field, value, raise_exception=False): + key = prepare_value(key) + field = prepare_value(field) + value = prepare_value(value) + + if self.active: + try: + return self.redis_store.hset(key, field, value) + except Exception as e: + self.__handle_exception(e, raise_exception, "set_hash_value", key) + + return None + def decrement_hash_value(self, key, value, raise_exception=False): return self.increment_hash_value(key, value, raise_exception, incr_by=-1) def increment_hash_value(self, key, value, raise_exception=False, incr_by=1): key = prepare_value(key) value = prepare_value(value) + if self.active: try: return self.redis_store.hincrby(key, value, incr_by) except Exception as e: self.__handle_exception(e, raise_exception, "increment_hash_value", key) + return None + + def get_hash_field(self, key, field, raise_exception=False): + key = prepare_value(key) + field = prepare_value(field) + + if self.active: + try: + return self.redis_store.hget(key, field) + except Exception as e: + self.__handle_exception(e, raise_exception, "get_hash_field", key) + + return None def get_all_from_hash(self, key, raise_exception=False): key = prepare_value(key) @@ -248,6 +333,8 @@ def get_all_from_hash(self, key, raise_exception=False): except Exception as e: self.__handle_exception(e, raise_exception, "get_all_from_hash", key) + return None + def set_hash_and_expire(self, key, values, expire_in_seconds, raise_exception=False): key = prepare_value(key) values = {prepare_value(k): prepare_value(v) for k, v in values.items()} @@ -258,6 +345,8 @@ def set_hash_and_expire(self, key, values, expire_in_seconds, raise_exception=Fa except Exception as e: self.__handle_exception(e, raise_exception, "set_hash_and_expire", key) + return None + def expire(self, key, expire_in_seconds, raise_exception=False): key = prepare_value(key) if self.active: diff --git a/notifications_utils/formatters.py b/notifications_utils/formatters.py index 79ee215f..c6c48d67 100644 --- a/notifications_utils/formatters.py +++ b/notifications_utils/formatters.py @@ -417,7 +417,7 @@ def list(self, body, ordered=True): '' "" '" @@ -429,7 +429,7 @@ def list(self, body, ordered=True): '
' - '
    ' + '
      ' "{}" "
    " "
' "" '" @@ -440,8 +440,8 @@ def list(self, body, ordered=True): def list_item(self, text): return ( - '
  • ' + '
  • ' "{}" "
  • " ).format(text.strip()) diff --git a/notifications_utils/jinja_templates/email/email_preview_template.jinja2 b/notifications_utils/jinja_templates/email/email_preview_template.jinja2 index 3224d64a..9a7d60af 100644 --- a/notifications_utils/jinja_templates/email/email_preview_template.jinja2 +++ b/notifications_utils/jinja_templates/email/email_preview_template.jinja2 @@ -75,7 +75,7 @@

    {% endif %} -
    +
    {{ body }}
    diff --git a/notifications_utils/jinja_templates/email/email_template.jinja2 b/notifications_utils/jinja_templates/email/email_template.jinja2 index 5b69eb67..70773033 100644 --- a/notifications_utils/jinja_templates/email/email_template.jinja2 +++ b/notifications_utils/jinja_templates/email/email_template.jinja2 @@ -51,8 +51,8 @@ style="border-collapse: collapse; width:100% !important; max-width: 580px; margin: 0 auto;" >
    -
    ' - '
      ' + '
        ' "{}" "
      " "
    - {{ body|safe }} + + {{ body|safe }}
    diff --git a/notifications_utils/template.py b/notifications_utils/template.py index 5bccecf4..c6e00d77 100644 --- a/notifications_utils/template.py +++ b/notifications_utils/template.py @@ -375,6 +375,8 @@ def __init__( self.allow_html = allow_html self.alt_text_en = alt_text_en self.alt_text_fr = alt_text_fr + self.text_direction_rtl = template.get("text_direction_rtl", False) + # set this again to make sure the correct either utils / downstream local jinja is used # however, don't set if we are in a test environment (to preserve the above mock) if "pytest" not in sys.modules: @@ -416,6 +418,7 @@ def __str__(self): "brand_name": self.brand_name, "alt_text_en": self.alt_text_en, "alt_text_fr": self.alt_text_fr, + "text_direction_rtl": self.text_direction_rtl, } ) @@ -483,6 +486,7 @@ def __init__( self.allow_html = allow_html self.alt_text_en = alt_text_en self.alt_text_fr = alt_text_fr + self.text_direction_rtl = template.get("text_direction_rtl", False) def __str__(self): return Markup( @@ -509,6 +513,7 @@ def __str__(self): "asset_domain": self.asset_domain, "alt_text_en": self.alt_text_en, "alt_text_fr": self.alt_text_fr, + "text_direction_rtl": self.text_direction_rtl, } ) ) diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py new file mode 100644 index 00000000..04adc1b9 --- /dev/null +++ b/tests/test_annual_limit.py @@ -0,0 +1,298 @@ +import uuid +from datetime import datetime +from unittest.mock import Mock + +import fakeredis +import pytest +from freezegun import freeze_time +from notifications_utils.clients.redis.annual_limit import ( + EMAIL_DELIVERED, + EMAIL_FAILED, + NEAR_EMAIL_LIMIT, + NEAR_SMS_LIMIT, + OVER_EMAIL_LIMIT, + OVER_SMS_LIMIT, + SMS_DELIVERED, + SMS_FAILED, + RedisAnnualLimit, + annual_limit_notifications_key, + annual_limit_status_key, +) +from notifications_utils.clients.redis.redis_client import RedisClient + + +@pytest.fixture(scope="function") +def mock_notification_count_types(): + return [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] + + +@pytest.fixture(scope="function") +def mock_annual_limit_statuses(): + return [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] + + +@pytest.fixture(scope="function") +def mocked_redis_pipeline(): + return Mock() + + +@pytest.fixture +def mock_redis_client(app, mocked_redis_pipeline, mocker): + app.config["REDIS_ENABLED"] = True + return build_redis_client(app, mocked_redis_pipeline, mocker) + + +@pytest.fixture(scope="function") +def better_mocked_redis_client(app): + app.config["REDIS_ENABLED"] = True + redis_client = RedisClient() + redis_client.redis_store = fakeredis.FakeStrictRedis(version=6) # type: ignore + redis_client.active = True + return redis_client + + +@pytest.fixture +def redis_annual_limit(mock_redis_client): + return RedisAnnualLimit(mock_redis_client) + + +def build_redis_client(app, mocked_redis_pipeline, mocker): + redis_client = RedisClient() + redis_client.init_app(app) + return redis_client + + +def build_annual_limit_client(mocker, better_mocked_redis_client): + annual_limit_client = RedisAnnualLimit(better_mocked_redis_client) + return annual_limit_client + + +@pytest.fixture(scope="function") +def mock_annual_limit_client(better_mocked_redis_client, mocker): + return RedisAnnualLimit(better_mocked_redis_client) + + +@pytest.fixture(scope="function") +def mocked_service_id(): + return str(uuid.uuid4()) + + +def test_notifications_key(mocked_service_id): + expected_key = f"annual-limit:{mocked_service_id}:notifications" + assert annual_limit_notifications_key(mocked_service_id) == expected_key + + +def test_annual_limits_key(mocked_service_id): + expected_key = f"annual-limit:{mocked_service_id}:status" + assert annual_limit_status_key(mocked_service_id) == expected_key + + +@pytest.mark.parametrize( + "increment_by, metric", + [ + (1, SMS_DELIVERED), + (1, EMAIL_DELIVERED), + (1, SMS_FAILED), + (1, EMAIL_FAILED), + (2, SMS_DELIVERED), + (2, EMAIL_DELIVERED), + (2, SMS_FAILED), + (2, EMAIL_FAILED), + ], +) +def test_increment_notification_count(mock_annual_limit_client, mocked_service_id, metric, increment_by): + for _ in range(increment_by): + mock_annual_limit_client.increment_notification_count(mocked_service_id, metric) + counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id) + assert int(counts[metric]) == increment_by + + +def test_get_notification_count(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.increment_notification_count(mocked_service_id, SMS_DELIVERED) + result = mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) + assert result == 1 + + +def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 + + +def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 + mock_annual_limit_client.clear_notification_counts(mocked_service_id) + assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0 + + +@pytest.mark.parametrize( + "service_ids", + [ + [ + str(uuid.uuid4()), + str(uuid.uuid4()), + str(uuid.uuid4()), + str(uuid.uuid4()), + ] + ], +) +def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notification_count_types, service_ids): + for service_id in service_ids: + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(service_id, field) + assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4 + + mock_annual_limit_client.reset_all_notification_counts() + + for service_id in service_ids: + assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 0 + + +def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow()) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_get_annual_limit_status(mock_annual_limit_client, mocked_service_id): + near_limit_date = datetime.utcnow() + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, near_limit_date) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) + assert result == near_limit_date.strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id): + for status in mock_annual_limit_statuses: + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow()) + assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4 + mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id) + assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0 + + +@freeze_time("2024-10-25 12:00:00.000000") +@pytest.mark.parametrize("seeded_at_value, expected_value", [(b"2024-10-25", True), (None, False)]) +def test_was_seeded_today(mock_annual_limit_client, seeded_at_value, expected_value, mocked_service_id, mocker): + mocker.patch.object(mock_annual_limit_client._redis_client, "get_hash_field", return_value=seeded_at_value) + result = mock_annual_limit_client.was_seeded_today(mocked_service_id) + assert result == expected_value + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_seeded_at(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_seeded_at(mocked_service_id) + result = mock_annual_limit_client.get_seeded_at(mocked_service_id) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +@pytest.mark.parametrize("seeded_at_value, expected_value", [(b"2024-10-25", "2024-10-25"), (None, None)]) +def test_get_seeded_at(mock_annual_limit_client, seeded_at_value, expected_value, mocked_service_id, mocker): + mocker.patch.object(mock_annual_limit_client._redis_client, "get_hash_field", return_value=seeded_at_value) + result = mock_annual_limit_client.get_seeded_at(mocked_service_id) + assert result == expected_value + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_over_sms_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_over_sms_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, OVER_SMS_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_nearing_email_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_nearing_email_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_EMAIL_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_over_email_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_over_email_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, OVER_EMAIL_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +def test_increment_sms_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_sms_delivered(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 2 + for field in mock_notification_count_types: + if field != SMS_DELIVERED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +def test_increment_sms_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_sms_failed(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_FAILED) == 2 + for field in mock_notification_count_types: + if field != SMS_FAILED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +def test_increment_email_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_email_delivered(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_DELIVERED) == 2 + for field in mock_notification_count_types: + if field != EMAIL_DELIVERED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +def test_increment_email_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_email_failed(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_FAILED) == 2 + for field in mock_notification_count_types: + if field != EMAIL_FAILED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_check_has_warning_been_sent(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow()) + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_EMAIL_LIMIT, datetime.utcnow()) + + assert mock_annual_limit_client.check_has_warning_been_sent(mocked_service_id, "sms") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) + assert mock_annual_limit_client.check_has_warning_been_sent(mocked_service_id, "email") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_check_has_over_limit_been_sent(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, OVER_SMS_LIMIT, datetime.utcnow()) + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, OVER_EMAIL_LIMIT, datetime.utcnow()) + + assert mock_annual_limit_client.check_has_over_limit_been_sent(mocked_service_id, "sms") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) + assert mock_annual_limit_client.check_has_over_limit_been_sent(mocked_service_id, "email") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) diff --git a/tests/test_formatters.py b/tests/test_formatters.py index 648bd6ff..0e6f0ab2 100644 --- a/tests/test_formatters.py +++ b/tests/test_formatters.py @@ -351,13 +351,10 @@ def test_hrule(markdown_function, expected): '' "" '" "" @@ -414,13 +411,10 @@ def test_ordered_list(markdown_function, markdown_input, expected): '
    ' - '
      ' - '
    1. one
    2. ' - '
    3. two
    4. ' - '
    5. three
    6. ' + '
        ' + '
      1. one
      2. ' + '
      3. two
      4. ' + '
      5. three
      6. ' "
      " "
    ' "" '" "" @@ -506,8 +500,8 @@ def test_pluses_dont_render_as_lists(markdown_function, expected): "* **title**: description", '
    ' - '
      ' - '
    • one
    • ' - '
    • two
    • ' - '
    • three
    • ' + '
        ' + '
      • one
      • ' + '
      • two
      • ' + '
      • three
      • ' "
      " "
    ' '
    ' - '
      ' - '
    • ' + '
        ' + '
      • ' "title: description
    ", ], ), @@ -1065,7 +1059,7 @@ class TestAddLanguageDivs: 1. item 2 1. item 3 [[/fr]]""", - f'
    {EMAIL_P_OPEN_TAG}Le français suis l\'anglais{EMAIL_P_CLOSE_TAG}
    • item 1
    • item 2
    • item 3
    {EMAIL_P_OPEN_TAG}bonjour{EMAIL_P_CLOSE_TAG}
    1. item 1
    2. item 2
    3. item 3
    ', # noqa + f'
    {EMAIL_P_OPEN_TAG}Le français suis l\'anglais{EMAIL_P_CLOSE_TAG}
    • item 1
    • item 2
    • item 3
    {EMAIL_P_OPEN_TAG}bonjour{EMAIL_P_CLOSE_TAG}
    1. item 1
    2. item 2
    3. item 3
    ', # noqa ), ("[[en]]No closing tag", f"{EMAIL_P_OPEN_TAG}[[en]]No closing tag{EMAIL_P_CLOSE_TAG}"), ("No opening tag[[/en]]", f"{EMAIL_P_OPEN_TAG}No opening tag[[/en]]{EMAIL_P_CLOSE_TAG}"), diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py index e54457b9..a334cedb 100644 --- a/tests/test_redis_client.py +++ b/tests/test_redis_client.py @@ -33,6 +33,27 @@ def better_mocked_redis_client(app): return redis_client +@pytest.fixture(scope="function") +def mocked_hash_structure(): + return { + "key1": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key2": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key3": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + } + + def build_redis_client(app, mocked_redis_pipeline, mocker): redis_client = RedisClient() redis_client.init_app(app) @@ -40,6 +61,7 @@ def build_redis_client(app, mocked_redis_pipeline, mocker): mocker.patch.object(redis_client.redis_store, "set") mocker.patch.object(redis_client.redis_store, "hincrby") mocker.patch.object(redis_client.redis_store, "hgetall", return_value={b"template-1111": b"8", b"template-2222": b"8"}) + mocker.patch.object(redis_client.redis_store, "hset") mocker.patch.object(redis_client.redis_store, "hmset") mocker.patch.object(redis_client.redis_store, "expire") mocker.patch.object(redis_client.redis_store, "delete") @@ -167,44 +189,6 @@ def test_should_build_rate_limit_cache_key(sample_service): assert rate_limit_cache_key(sample_service.id, "TEST") == "{}-TEST".format(sample_service.id) -def test_decrement_hash_value_should_decrement_value_by_one_for_key(mocked_redis_client): - key = "12345" - value = "template-1111" - - mocked_redis_client.decrement_hash_value(key, value, -1) - mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, -1) - - -def test_incr_hash_value_should_increment_value_by_one_for_key(mocked_redis_client): - key = "12345" - value = "template-1111" - - mocked_redis_client.increment_hash_value(key, value) - mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, 1) - - -def test_get_all_from_hash_returns_hash_for_key(mocked_redis_client): - key = "12345" - assert mocked_redis_client.get_all_from_hash(key) == {b"template-1111": b"8", b"template-2222": b"8"} - mocked_redis_client.redis_store.hgetall.assert_called_with(key) - - -def test_set_hash_and_expire(mocked_redis_client): - key = "hash-key" - values = {"key": 10} - mocked_redis_client.set_hash_and_expire(key, values, 1) - mocked_redis_client.redis_store.hmset.assert_called_with(key, values) - mocked_redis_client.redis_store.expire.assert_called_with(key, 1) - - -def test_set_hash_and_expire_converts_values_to_valid_types(mocked_redis_client): - key = "hash-key" - values = {uuid.UUID(int=0): 10} - mocked_redis_client.set_hash_and_expire(key, values, 1) - mocked_redis_client.redis_store.hmset.assert_called_with(key, {"00000000-0000-0000-0000-000000000000": 10}) - mocked_redis_client.redis_store.expire.assert_called_with(key, 1) - - @freeze_time("2001-01-01 12:00:00.000000") def test_should_add_correct_calls_to_the_pipe(mocked_redis_client, mocked_redis_pipeline): mocked_redis_client.exceeded_rate_limit("key", 100, 100) @@ -315,3 +299,131 @@ def test_get_length_of_sorted_set_returns_none_if_not_active(self, better_mocked better_mocked_redis_client.active = False ret = better_mocked_redis_client.get_length_of_sorted_set("cache_key", min_score=0, max_score=100) assert ret == 0 + + +class TestRedisHashes: + @pytest.mark.parametrize( + "hash_key, fields_to_delete, expected_deleted, check_if_no_longer_exists", + [ + ("test:hash:key1", ["field1", "field2"], 2, False), # Delete specific fields in a hash + ("test:hash:key1", None, 3, True), # Delete All fields in a specific hash + ("test:hash:*", ["field1", "field2"], 6, False), # Delete specific fields in a group of hashes + ("test:hash:*", None, 9, True), # Delete All fields in a group of hashes + ], + ) + def test_delete_hash_fields( + self, + better_mocked_redis_client, + hash_key, + fields_to_delete, + expected_deleted, + check_if_no_longer_exists, + mocked_hash_structure, + ): + # set up the hashes to be deleted + for key, fields in mocked_hash_structure.items(): + better_mocked_redis_client.bulk_set_hash_fields(key=f"test:hash:{key}", mapping=fields) + + num_deleted = better_mocked_redis_client.delete_hash_fields(hashes=hash_key, fields=fields_to_delete) + + # Deleting all hash fields by pattern + if check_if_no_longer_exists and "*" in hash_key: + for key in mocked_hash_structure.keys(): + assert better_mocked_redis_client.redis_store.exists(f"test:hash:{key}") == 0 + # Deleting a specific hash + elif check_if_no_longer_exists: + assert better_mocked_redis_client.redis_store.exists(f"test:hash:{hash_key}") == 0 + + # Make sure we've deleted the correct number of fields + assert sum(num_deleted) == expected_deleted + + def test_get_hash_field(self, mocked_redis_client): + key = "12345" + field = "template-1111" + mocked_redis_client.redis_store.hget = Mock(return_value=b"8") + assert mocked_redis_client.get_hash_field(key, field) == b"8" + mocked_redis_client.redis_store.hget.assert_called_with(key, field) + + def test_set_hash_value(self, mocked_redis_client): + key = "12345" + field = "template-1111" + value = 8 + mocked_redis_client.set_hash_value(key, field, value) + mocked_redis_client.redis_store.hset.assert_called_with(key, field, value) + + @pytest.mark.parametrize( + "hash, updates, expected", + [ + ( + { + "key1": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key2": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key3": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + }, + { + "field1": "value2", + "field2": 3, + "field3": "value4".encode("utf-8"), + }, + { + b"field1": b"value2", + b"field2": b"3", + b"field3": b"value4", + }, + ) + ], + ) + def test_bulk_set_hash_fields(self, better_mocked_redis_client, hash, updates, expected): + for key, fields in hash.items(): + for field, value in fields.items(): + better_mocked_redis_client.set_hash_value(key, field, value) + + better_mocked_redis_client.bulk_set_hash_fields(pattern="key*", mapping=updates) + + for key, _ in hash.items(): + assert better_mocked_redis_client.redis_store.hgetall(key) == expected + + def test_decrement_hash_value_should_decrement_value_by_one_for_key(self, mocked_redis_client): + key = "12345" + value = "template-1111" + + mocked_redis_client.decrement_hash_value(key, value, -1) + mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, -1) + + def test_incr_hash_value_should_increment_value_by_one_for_key(self, mocked_redis_client): + key = "12345" + value = "template-1111" + + mocked_redis_client.increment_hash_value(key, value) + mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, 1) + + def test_get_all_from_hash_returns_hash_for_key(self, mocked_redis_client): + key = "12345" + assert mocked_redis_client.get_all_from_hash(key) == {b"template-1111": b"8", b"template-2222": b"8"} + mocked_redis_client.redis_store.hgetall.assert_called_with(key) + + def test_set_hash_and_expire(self, mocked_redis_client): + key = "hash-key" + values = {"key": 10} + mocked_redis_client.set_hash_and_expire(key, values, 1) + mocked_redis_client.redis_store.hmset.assert_called_with(key, values) + mocked_redis_client.redis_store.expire.assert_called_with(key, 1) + + def test_set_hash_and_expire_converts_values_to_valid_types(self, mocked_redis_client): + key = "hash-key" + values = {uuid.UUID(int=0): 10} + mocked_redis_client.set_hash_and_expire(key, values, 1) + mocked_redis_client.redis_store.hmset.assert_called_with(key, {"00000000-0000-0000-0000-000000000000": 10}) + mocked_redis_client.redis_store.expire.assert_called_with(key, 1) diff --git a/tests/test_template_types.py b/tests/test_template_types.py index abf9acc8..6d220b0f 100644 --- a/tests/test_template_types.py +++ b/tests/test_template_types.py @@ -2255,3 +2255,34 @@ def test_image_class_applied_to_logo(template_class, filename, expected_html_cla def test_image_not_present_if_no_logo(template_class): # can't test that the html doesn't move in utils - tested in template preview instead assert "