From 25ef1ae7703ec68b622bf24c4ea08f20ded0bab3 Mon Sep 17 00:00:00 2001 From: William B <7444334+whabanks@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:39:22 -0400 Subject: [PATCH] Fix bug in delete_hash_fields (#337) - delete_hash_fields was occasionally missing fields when used to clear annual limit notification counts - Refactored annual limit client methods that returned numeric data to return 0 instead of None when no data was found in the queried hash --- .github/actions/waffles/requirements.txt | 2 +- .../clients/redis/annual_limit.py | 64 +++++++++-- .../clients/redis/redis_client.py | 12 +-- pyproject.toml | 2 +- tests/test_annual_limit.py | 101 ++++++++++-------- tests/test_redis_client.py | 2 - 6 files changed, 117 insertions(+), 66 deletions(-) diff --git a/.github/actions/waffles/requirements.txt b/.github/actions/waffles/requirements.txt index 1309ae3e..665564fc 100644 --- a/.github/actions/waffles/requirements.txt +++ b/.github/actions/waffles/requirements.txt @@ -1,4 +1,4 @@ docopt==0.6.2 Flask==2.3.3 markupsafe==2.1.5 -git+https://github.com/cds-snc/notifier-utils.git@52.3.9#egg=notifications-utils +git+https://github.com/cds-snc/notifier-utils.git@52.4.0#egg=notifications-utils diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index d508d392..6fa28d3f 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -32,7 +32,7 @@ SMS_FAILED = "sms_failed" EMAIL_FAILED = "email_failed" -NOTIFICATIONS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] +NOTIFICATION_FIELDS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] NEAR_SMS_LIMIT = "near_sms_limit" NEAR_EMAIL_LIMIT = "near_email_limit" @@ -40,7 +40,7 @@ OVER_EMAIL_LIMIT = "over_email_limit" SEEDED_AT = "seeded_at" -STATUSES = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] +STATUS_FIELDS = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] def annual_limit_notifications_key(service_id): @@ -57,16 +57,42 @@ def annual_limit_status_key(service_id): return f"annual-limit:{service_id}:status" -def decode_byte_dict(byte_dict: dict, value_type=str): +def prepare_byte_dict(byte_dict: dict, value_type=str, required_keys=None): """ Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings. """ # Check if expected_value_type is one of the allowed types if value_type not in {int, float, str}: raise ValueError("expected_value_type must be int, float, or str") - if byte_dict is None or not byte_dict.items(): - return None - return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()} + + decoded_dict = ( + {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()} if byte_dict else {} + ) + + if required_keys: + for key in required_keys: + default_value = 0 if value_type in {int, float} else None + decoded_dict.setdefault(key, default_value) + return decoded_dict + + +def init_missing_keys( + required_keys: list, + value_type=str, + incomplete_dict: dict = {}, +): + """Ensures that all expected keys are present in dicts returned from this module. Initializes empty values to defaults if not. + + Args: + incomplete_dict (dict): A dictionary to check for required keys. + required_keys (list): The keys that must be present in the dictionary. + value_type (_type_, optional): The datatype of the values in the dict. Defaults to str. + + Raises: + ValueError: If the value_type is not int, float, or str. + """ + if value_type not in {int, float, str}: + raise ValueError("expected_value_type must be int, float, or str") class RedisAnnualLimit: @@ -91,13 +117,15 @@ def get_notification_count(self, service_id: str, field: str): Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.) """ count = self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field) - return count and int(count.decode("utf-8")) + return 0 if not count else int(count.decode("utf-8")) def get_all_notification_counts(self, service_id: str): """ Retrieves all daily notification metrics for a service. """ - return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int) + return prepare_byte_dict( + self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int, NOTIFICATION_FIELDS + ) def reset_all_notification_counts(self, service_ids=None): """Resets all daily notification metrics. @@ -112,7 +140,7 @@ def reset_all_notification_counts(self, service_ids=None): else [annual_limit_notifications_key(service_id) for service_id in service_ids] ) - self._redis_client.delete_hash_fields(hashes=hashes) + self._redis_client.delete_hash_fields(hashes=hashes, fields=NOTIFICATION_FIELDS) def seed_annual_limit_notifications(self, service_id: str, mapping: dict): """Seeds annual limit notifications for a service. @@ -187,7 +215,7 @@ def get_all_annual_limit_statuses(self, service_id: str): Returns: dict | None: A dictionary of annual limit statuses or None if no statuses are found """ - return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id))) + return prepare_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id)), str, STATUS_FIELDS) def clear_annual_limit_statuses(self, service_id: str): self._redis_client.expire(f"{annual_limit_status_key(service_id)}", -1) @@ -237,3 +265,19 @@ def check_has_over_limit_been_sent(self, service_id: str, message_type: str): """ field_to_fetch = OVER_SMS_LIMIT if message_type == "sms" else OVER_EMAIL_LIMIT return self.get_annual_limit_status(service_id, field_to_fetch) + + def delete_all_annual_limit_hashes(self, service_ids=None): + """ + THIS SHOULD NOT BE CALLED IN CODE. This is a helper method for testing purposes only. + Clears all annual limit hashes in Redis + + Args: + service_ids (Optional): A list of service_ids to clear annual limit hashes for. Clears all services if None. + """ + if not service_ids: + self._redis_client.delete_cache_keys_by_pattern(annual_limit_notifications_key("*")) + self._redis_client.delete_cache_keys_by_pattern(annual_limit_status_key("*")) + else: + for service_id in service_ids: + self._redis_client.delete(annual_limit_notifications_key(service_id)) + self._redis_client.delete(annual_limit_status_key(service_id)) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index c96ff22a..08cb6388 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -1,7 +1,7 @@ import numbers import uuid from time import time -from typing import Any, Dict, Optional +from typing import Any, Dict from flask import current_app from flask_redis import FlaskRedis @@ -82,7 +82,7 @@ def delete_cache_keys_by_pattern(self, pattern): return 0 # TODO: Refactor and simplify this to use HEXPIRE when we upgrade Redis to 7.4.0 - def delete_hash_fields(self, hashes: (str | list), fields: Optional[list] = None, raise_exception=False): + def delete_hash_fields(self, hashes: (str | list), fields: list, raise_exception=False): """Deletes fields from the specified hashes. if fields is `None`, then all fields from the hashes are deleted, deleting the hash entirely. Args: @@ -98,13 +98,7 @@ def delete_hash_fields(self, hashes: (str | list), fields: Optional[list] = None # When fields are passed in, use the list as is # When hashes is a list, and no fields are passed in, fetch the fields from the first hash in the list # otherwise we know we're going scan iterate over a pattern so we'll fetch the fields on the first pass in the loop below - fields = ( - [prepare_value(f) for f in fields] - if fields is not None - else self.redis_store.hkeys(hashes[0]) - if isinstance(hashes, list) - else None - ) + fields = [prepare_value(f) for f in fields] # Use a pipeline to atomically delete fields from each hash. pipe = self.redis_store.pipeline() # if hashes is not a list, we're scan iterating over keys matching a pattern diff --git a/pyproject.toml b/pyproject.toml index 2c41ce17..ca4e565a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "notifications-utils" -version = "52.3.9" +version = "52.4.0" description = "Shared python code for Notification - Provides logging utils etc." authors = ["Canadian Digital Service"] license = "MIT license" diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py index c32557c7..c6c6b0e6 100644 --- a/tests/test_annual_limit.py +++ b/tests/test_annual_limit.py @@ -10,10 +10,12 @@ EMAIL_FAILED, NEAR_EMAIL_LIMIT, NEAR_SMS_LIMIT, + NOTIFICATION_FIELDS, OVER_EMAIL_LIMIT, OVER_SMS_LIMIT, SMS_DELIVERED, SMS_FAILED, + STATUS_FIELDS, RedisAnnualLimit, annual_limit_notifications_key, annual_limit_status_key, @@ -21,16 +23,6 @@ from notifications_utils.clients.redis.redis_client import RedisClient -@pytest.fixture(scope="function") -def mock_notification_count_types(): - return [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] - - -@pytest.fixture(scope="function") -def mock_annual_limit_statuses(): - return [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] - - @pytest.fixture(scope="function") def mocked_redis_pipeline(): return Mock() @@ -114,25 +106,32 @@ def test_get_notification_count(mock_annual_limit_client, mocked_service_id): def test_get_notification_count_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id): - assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) is None + assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 0 -def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): - for field in mock_notification_count_types: +def test_get_all_notification_counts(mock_annual_limit_client, mocked_service_id): + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) - assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 + counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id) + assert len(counts) == 4 + assert all(isinstance(value, int) for value in counts.values()) def test_get_all_notification_counts_returns_none_if_fields_do_not_exist(mock_annual_limit_client, mocked_service_id): - assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None + notification_counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id) + assert set(notification_counts.keys()) == set(NOTIFICATION_FIELDS) + assert all(value == 0 for value in notification_counts.values()) -def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): - for field in mock_notification_count_types: +def test_clear_notification_counts(mock_annual_limit_client, mocked_service_id): + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 + mock_annual_limit_client.clear_notification_counts(mocked_service_id) - assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None + counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id) + assert set(counts.keys()) == set(NOTIFICATION_FIELDS) + assert all(value == 0 for value in counts.values()) @pytest.mark.parametrize( @@ -146,16 +145,20 @@ def test_clear_notification_counts(mock_annual_limit_client, mock_notification_c ] ], ) -def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notification_count_types, service_ids): +def test_bulk_reset_notification_counts(mock_annual_limit_client, service_ids): for service_id in service_ids: - for field in mock_notification_count_types: + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(service_id, field) - assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4 - mock_annual_limit_client.reset_all_notification_counts() + counts = mock_annual_limit_client.get_all_notification_counts(service_id) + assert set(counts.keys()) == set(NOTIFICATION_FIELDS) + assert all(value > 0 for value in counts.values()) + mock_annual_limit_client.reset_all_notification_counts() for service_id in service_ids: - assert mock_annual_limit_client.get_all_notification_counts(service_id) is None + counts = mock_annual_limit_client.get_all_notification_counts(service_id) + assert set(counts.keys()) == set(NOTIFICATION_FIELDS) + assert all(value == 0 for value in counts.values()) def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id): @@ -177,23 +180,35 @@ def test_get_annual_limit_status_returns_none_when_fields_do_not_exist(mock_annu @freeze_time("2024-10-25 12:00:00.000000") -def test_get_all_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id): - for status in mock_annual_limit_statuses: +def test_get_all_annual_limit_statuses(mock_annual_limit_client, mocked_service_id): + for status in STATUS_FIELDS: mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow()) - assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4 + + statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) + assert len(statuses) == 4 + assert all(value is not None for value in statuses.values()) def test_get_all_annual_limit_statuses_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id): - assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None + statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) + assert set(statuses.keys()) == set(STATUS_FIELDS) + assert all(value is None for value in statuses.values()) @freeze_time("2024-10-25 12:00:00.000000") -def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id): - for status in mock_annual_limit_statuses: +def test_clear_annual_limit_statuses(mock_annual_limit_client, mocked_service_id): + for status in STATUS_FIELDS: mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow()) - assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4 + + statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) + assert len(statuses) == 4 + assert all(value == "2024-10-25" for value in statuses.values()) + mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id) - assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None + + statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) + assert set(statuses.keys()) == set(STATUS_FIELDS) + assert all(value is None for value in statuses.values()) @freeze_time("2024-10-25 12:00:00.000000") @@ -251,50 +266,50 @@ def test_set_over_email_limit(mock_annual_limit_client, mocked_service_id): assert result == datetime.utcnow().strftime("%Y-%m-%d") -def test_increment_sms_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): - for field in mock_notification_count_types: +def test_increment_sms_delivered(mock_annual_limit_client, mocked_service_id): + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) mock_annual_limit_client.increment_sms_delivered(mocked_service_id) assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 2 - for field in mock_notification_count_types: + for field in NOTIFICATION_FIELDS: if field != SMS_DELIVERED: assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 -def test_increment_sms_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): - for field in mock_notification_count_types: +def test_increment_sms_failed(mock_annual_limit_client, mocked_service_id): + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) mock_annual_limit_client.increment_sms_failed(mocked_service_id) assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_FAILED) == 2 - for field in mock_notification_count_types: + for field in NOTIFICATION_FIELDS: if field != SMS_FAILED: assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 -def test_increment_email_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): - for field in mock_notification_count_types: +def test_increment_email_delivered(mock_annual_limit_client, mocked_service_id): + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) mock_annual_limit_client.increment_email_delivered(mocked_service_id) assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_DELIVERED) == 2 - for field in mock_notification_count_types: + for field in NOTIFICATION_FIELDS: if field != EMAIL_DELIVERED: assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 -def test_increment_email_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): - for field in mock_notification_count_types: +def test_increment_email_failed(mock_annual_limit_client, mocked_service_id): + for field in NOTIFICATION_FIELDS: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) mock_annual_limit_client.increment_email_failed(mocked_service_id) assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_FAILED) == 2 - for field in mock_notification_count_types: + for field in NOTIFICATION_FIELDS: if field != EMAIL_FAILED: assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1 diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py index a334cedb..5a4506e1 100644 --- a/tests/test_redis_client.py +++ b/tests/test_redis_client.py @@ -306,9 +306,7 @@ class TestRedisHashes: "hash_key, fields_to_delete, expected_deleted, check_if_no_longer_exists", [ ("test:hash:key1", ["field1", "field2"], 2, False), # Delete specific fields in a hash - ("test:hash:key1", None, 3, True), # Delete All fields in a specific hash ("test:hash:*", ["field1", "field2"], 6, False), # Delete specific fields in a group of hashes - ("test:hash:*", None, 9, True), # Delete All fields in a group of hashes ], ) def test_delete_hash_fields(