diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py new file mode 100644 index 00000000..54823c50 --- /dev/null +++ b/notifications_utils/clients/redis/annual_limit.py @@ -0,0 +1,130 @@ +"""This module stores daily notification counts and annual limit statuses for a service in Redis.""" + +from datetime import datetime + +from notifications_utils.clients.redis.redis_client import RedisClient + +SMS_DELIVERED = "sms_delivered" +EMAIL_DELIVERED = "email_delivered" +SMS_FAILED = "sms_failed" +EMAIL_FAILED = "email_failed" + +NEAR_SMS_LIMIT = "near_sms_limit" +NEAR_EMAIL_LIMIT = "near_email_limit" +OVER_SMS_LIMIT = "over_sms_limit" +OVER_EMAIL_LIMIT = "over_email_limit" + + +def notifications_key(service_id): + """ + Generates the Redis hash key for storing daily metrics of a service. + """ + return f"annual-limit:{service_id}:notifications" + + +def annual_limit_status_key(service_id): + """ + Generates the Redis hash key for storing annual limit information of a service. + """ + return f"annual-limit:{service_id}:status" + + +def decode_byte_dict(dict: dict): + """ + Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings. + """ + return {key.decode("utf-8"): value.decode("utf-8") for key, value in dict.items()} + + +class RedisAnnualLimit: + def __init__(self, redis: RedisClient): + self._redis_client = redis + + def init_app(self, app, *args, **kwargs): + pass + + def increment_notification_count(self, service_id: str, field: str): + self._redis_client.increment_hash_value(notifications_key(service_id), field) + + def get_notification_count(self, service_id: str, field: str): + """ + Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.) + """ + return int(self._redis_client.get_hash_field(notifications_key(service_id), field)) + + def get_all_notification_counts(self, service_id: str): + """ + Retrieves all daily notification metrics for a service. + """ + return decode_byte_dict(self._redis_client.get_all_from_hash(notifications_key(service_id))) + + def clear_notification_counts(self, service_id: str): + """ + Clears all daily notification metrics for a service. + """ + self._redis_client.expire(notifications_key(service_id), -1) + + def set_annual_limit_status(self, service_id: str, field: str, value: datetime): + """ + Sets the status (e.g., 'nearing_limit', 'over_limit') in the annual limits Redis hash. + """ + self._redis_client.set_hash_value(annual_limit_status_key(service_id), field, value.strftime("%Y-%m-%d")) + + def get_annual_limit_status(self, service_id: str, field: str): + """ + Retrieves the value of a specific annual limit status from the Redis hash. + """ + response = self._redis_client.get_hash_field(annual_limit_status_key(service_id), field) + return response.decode("utf-8") if response is not None else None + + def get_all_annual_limit_statuses(self, service_id: str): + return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id))) + + def clear_annual_limit_statuses(self, service_id: str): + self._redis_client.expire(f"{annual_limit_status_key(service_id)}", -1) + + # Helper methods for daily metrics + def increment_sms_delivered(self, service_id: str): + self.increment_notification_count(service_id, SMS_DELIVERED) + + def increment_sms_failed(self, service_id: str): + self.increment_notification_count(service_id, SMS_FAILED) + + def increment_email_delivered(self, service_id: str): + self.increment_notification_count(service_id, EMAIL_DELIVERED) + + def increment_email_failed(self, service_id: str): + self.increment_notification_count(service_id, EMAIL_FAILED) + + # Helper methods for annual limits + def set_nearing_sms_limit(self, service_id: str): + self.set_annual_limit_status(service_id, NEAR_SMS_LIMIT, datetime.utcnow()) + + def set_nearing_email_limit(self, service_id: str): + self.set_annual_limit_status(service_id, NEAR_EMAIL_LIMIT, datetime.utcnow()) + + def set_over_sms_limit(self, service_id: str): + self.set_annual_limit_status(service_id, OVER_SMS_LIMIT, datetime.utcnow()) + + def set_over_email_limit(self, service_id: str): + self.set_annual_limit_status(service_id, OVER_EMAIL_LIMIT, datetime.utcnow()) + + def check_has_warning_been_sent(self, service_id: str, message_type: str): + """ + Check if an annual limit warning email has been sent to the service. + Returns None if no warning has been sent, otherwise returns the date the + last warning was issued. + When a service's annual limit is increased this value is reset. + """ + field_to_fetch = NEAR_SMS_LIMIT if message_type == "sms" else NEAR_EMAIL_LIMIT + return self.get_annual_limit_status(service_id, field_to_fetch) + + def check_has_over_limit_been_sent(self, service_id: str, message_type: str): + """ + Check if an annual limit exceeded email has been sent to the service. + Returns None if no exceeded email has been sent, otherwise returns the date the + last exceeded email was issued. + When a service's annual limit is increased this value is reset. + """ + field_to_fetch = OVER_SMS_LIMIT if message_type == "sms" else OVER_EMAIL_LIMIT + return self.get_annual_limit_status(service_id, field_to_fetch) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index c85b2678..3c9b5048 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -228,17 +228,42 @@ def get(self, key, raise_exception=False): return None + def set_hash_value(self, key, field, value, raise_exception=False): + key = prepare_value(key) + field = prepare_value(field) + value = prepare_value(value) + + if self.active: + try: + self.redis_store.hset(key, field, value) + except Exception as e: + self.__handle_exception(e, raise_exception, "set_hash_value", key) + def decrement_hash_value(self, key, value, raise_exception=False): return self.increment_hash_value(key, value, raise_exception, incr_by=-1) def increment_hash_value(self, key, value, raise_exception=False, incr_by=1): key = prepare_value(key) value = prepare_value(value) + if self.active: try: return self.redis_store.hincrby(key, value, incr_by) except Exception as e: self.__handle_exception(e, raise_exception, "increment_hash_value", key) + return None + + def get_hash_field(self, key, field, raise_exception=False): + key = prepare_value(key) + field = prepare_value(field) + + if self.active: + try: + return self.redis_store.hget(key, field) + except Exception as e: + self.__handle_exception(e, raise_exception, "get_hash_field", key) + + return None def get_all_from_hash(self, key, raise_exception=False): key = prepare_value(key) diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py new file mode 100644 index 00000000..4b2a22ea --- /dev/null +++ b/tests/test_annual_limit.py @@ -0,0 +1,252 @@ +import uuid +from datetime import datetime +from unittest.mock import Mock + +import fakeredis +import pytest +from freezegun import freeze_time +from notifications_utils.clients.redis.annual_limit import ( + EMAIL_DELIVERED, + EMAIL_FAILED, + NEAR_EMAIL_LIMIT, + NEAR_SMS_LIMIT, + OVER_EMAIL_LIMIT, + OVER_SMS_LIMIT, + SMS_DELIVERED, + SMS_FAILED, + RedisAnnualLimit, + annual_limit_status_key, + notifications_key, +) +from notifications_utils.clients.redis.redis_client import RedisClient + + +@pytest.fixture(scope="function") +def mock_notification_count_types(): + return [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] + + +@pytest.fixture(scope="function") +def mock_annual_limit_statuses(): + return [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] + + +@pytest.fixture(scope="function") +def mocked_redis_pipeline(): + return Mock() + + +@pytest.fixture +def mock_redis_client(app, mocked_redis_pipeline, mocker): + app.config["REDIS_ENABLED"] = True + return build_redis_client(app, mocked_redis_pipeline, mocker) + + +@pytest.fixture(scope="function") +def better_mocked_redis_client(app): + app.config["REDIS_ENABLED"] = True + redis_client = RedisClient() + redis_client.redis_store = fakeredis.FakeStrictRedis(version=6) # type: ignore + redis_client.active = True + return redis_client + + +@pytest.fixture +def redis_annual_limit(mock_redis_client): + return RedisAnnualLimit(mock_redis_client) + + +def build_redis_client(app, mocked_redis_pipeline, mocker): + redis_client = RedisClient() + redis_client.init_app(app) + return redis_client + + +def build_annual_limit_client(mocker, better_mocked_redis_client): + annual_limit_client = RedisAnnualLimit(better_mocked_redis_client) + return annual_limit_client + + +@pytest.fixture(scope="function") +def mock_annual_limit_client(better_mocked_redis_client, mocker): + return RedisAnnualLimit(better_mocked_redis_client) + + +@pytest.fixture(scope="function") +def mocked_service_id(): + return str(uuid.uuid4()) + + +def test_notifications_key(mocked_service_id): + expected_key = f"annual-limit:{mocked_service_id}:notifications" + assert notifications_key(mocked_service_id) == expected_key + + +def test_annual_limits_key(mocked_service_id): + expected_key = f"annual-limit:{mocked_service_id}:status" + assert annual_limit_status_key(mocked_service_id) == expected_key + + +@pytest.mark.parametrize( + "increment_by, metric", + [ + (1, SMS_DELIVERED), + (1, EMAIL_DELIVERED), + (1, SMS_FAILED), + (1, EMAIL_FAILED), + (2, SMS_DELIVERED), + (2, EMAIL_DELIVERED), + (2, SMS_FAILED), + (2, EMAIL_FAILED), + ], +) +def test_increment_notification_count(mock_annual_limit_client, mocked_service_id, metric, increment_by): + for _ in range(increment_by): + mock_annual_limit_client.increment_notification_count(mocked_service_id, metric) + counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id) + assert int(counts[metric]) == increment_by + + +def test_get_notification_count(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.increment_notification_count(mocked_service_id, SMS_DELIVERED) + result = mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) + assert result == 1 + + +def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 + + +def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 + mock_annual_limit_client.clear_notification_counts(mocked_service_id) + assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0 + + +def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow()) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_get_annual_limit_status(mock_annual_limit_client, mocked_service_id): + near_limit_date = datetime.utcnow() + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, near_limit_date) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) + assert result == near_limit_date.strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id): + for status in mock_annual_limit_statuses: + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow()) + assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4 + mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id) + assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0 + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_over_sms_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_over_sms_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, OVER_SMS_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_nearing_email_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_nearing_email_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_EMAIL_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_over_email_limit(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_over_email_limit(mocked_service_id) + result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, OVER_EMAIL_LIMIT) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +def test_increment_sms_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_sms_delivered(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 2 + for field in mock_notification_count_types: + if field != SMS_DELIVERED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +def test_increment_sms_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_sms_failed(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_FAILED) == 2 + for field in mock_notification_count_types: + if field != SMS_FAILED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +def test_increment_email_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_email_delivered(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_DELIVERED) == 2 + for field in mock_notification_count_types: + if field != EMAIL_DELIVERED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +def test_increment_email_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(mocked_service_id, field) + + mock_annual_limit_client.increment_email_failed(mocked_service_id) + + assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_FAILED) == 2 + for field in mock_notification_count_types: + if field != EMAIL_FAILED: + assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_check_has_warning_been_sent(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow()) + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_EMAIL_LIMIT, datetime.utcnow()) + + assert mock_annual_limit_client.check_has_warning_been_sent(mocked_service_id, "sms") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) + assert mock_annual_limit_client.check_has_warning_been_sent(mocked_service_id, "email") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_check_has_over_limit_been_sent(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, OVER_SMS_LIMIT, datetime.utcnow()) + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, OVER_EMAIL_LIMIT, datetime.utcnow()) + + assert mock_annual_limit_client.check_has_over_limit_been_sent(mocked_service_id, "sms") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) + assert mock_annual_limit_client.check_has_over_limit_been_sent(mocked_service_id, "email") == datetime.utcnow().strftime( + "%Y-%m-%d" + ) diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py index e54457b9..4a9a5a62 100644 --- a/tests/test_redis_client.py +++ b/tests/test_redis_client.py @@ -40,6 +40,7 @@ def build_redis_client(app, mocked_redis_pipeline, mocker): mocker.patch.object(redis_client.redis_store, "set") mocker.patch.object(redis_client.redis_store, "hincrby") mocker.patch.object(redis_client.redis_store, "hgetall", return_value={b"template-1111": b"8", b"template-2222": b"8"}) + mocker.patch.object(redis_client.redis_store, "hset") mocker.patch.object(redis_client.redis_store, "hmset") mocker.patch.object(redis_client.redis_store, "expire") mocker.patch.object(redis_client.redis_store, "delete") @@ -167,6 +168,22 @@ def test_should_build_rate_limit_cache_key(sample_service): assert rate_limit_cache_key(sample_service.id, "TEST") == "{}-TEST".format(sample_service.id) +def test_get_hash_field(mocked_redis_client): + key = "12345" + field = "template-1111" + mocked_redis_client.redis_store.hget = Mock(return_value=b"8") + assert mocked_redis_client.get_hash_field(key, field) == b"8" + mocked_redis_client.redis_store.hget.assert_called_with(key, field) + + +def test_set_hash_value(mocked_redis_client): + key = "12345" + field = "template-1111" + value = 8 + mocked_redis_client.set_hash_value(key, field, value) + mocked_redis_client.redis_store.hset.assert_called_with(key, field, value) + + def test_decrement_hash_value_should_decrement_value_by_one_for_key(mocked_redis_client): key = "12345" value = "template-1111"