From 9574f24a7d0c3745db7e52e9e7cfb15c5edd6b39 Mon Sep 17 00:00:00 2001 From: wbanks Date: Mon, 4 Nov 2024 15:29:05 -0500 Subject: [PATCH 1/8] Add method to bulk reset notification counts - Added new redis_client base method to bulk update hash values by pattern --- .github/actions/waffles/requirements.txt | 2 +- .../clients/redis/annual_limit.py | 10 +++++ .../clients/redis/redis_client.py | 14 ++++++ pyproject.toml | 2 +- tests/test_annual_limit.py | 25 +++++++++++ tests/test_redis_client.py | 45 +++++++++++++++++++ 6 files changed, 96 insertions(+), 2 deletions(-) diff --git a/.github/actions/waffles/requirements.txt b/.github/actions/waffles/requirements.txt index 5b3b8350..80b4baf0 100644 --- a/.github/actions/waffles/requirements.txt +++ b/.github/actions/waffles/requirements.txt @@ -1,4 +1,4 @@ docopt==0.6.2 Flask==2.3.3 markupsafe==2.1.5 -git+https://github.com/cds-snc/notifier-utils.git@52.3.6#egg=notifications-utils +git+https://github.com/cds-snc/notifier-utils.git@52.3.7#egg=notifications-utils diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 54823c50..58f2f17c 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -58,6 +58,16 @@ def get_all_notification_counts(self, service_id: str): """ return decode_byte_dict(self._redis_client.get_all_from_hash(notifications_key(service_id))) + def reset_all_notification_counts(self): + """ + Resets all daily notification metrics for all services. Uses non-blocking scan_iter method to avoid locking the Redis server. + """ + pattern = notifications_key("*") # All notification keys regardless of service_id + + self._redis_client.bulk_set_hash_fields( + pattern, ({"email_delivered": 0, "email_failed": 0, "sms_delivered": 0, "sms_failed": 0}) + ) + def clear_notification_counts(self, service_id: str): """ Clears all daily notification metrics for a service. diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index 3c9b5048..259f084f 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -81,6 +81,20 @@ def delete_cache_keys_by_pattern(self, pattern): return self.scripts["delete-keys-by-pattern"](args=[pattern]) return 0 + def bulk_set_hash_fields(self, pattern, mapping, raise_exception=False): + """ + Bulk set hash fields. + :param pattern: the pattern to match keys + :param mappting: the mapping of fields to set + :param raise_exception: True if we should allow the exception to bubble up + """ + if self.active: + try: + for key in self.redis_store.scan_iter(pattern): + self.redis_store.hmset(key, mapping) + except Exception as e: + self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", pattern) + def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False): """ Rate limiting. diff --git a/pyproject.toml b/pyproject.toml index 5688b4a5..4aa32641 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "notifications-utils" -version = "52.3.6" +version = "52.3.7" description = "Shared python code for Notification - Provides logging utils etc." 
authors = ["Canadian Digital Service"] license = "MIT license" diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py index 4b2a22ea..62033ef9 100644 --- a/tests/test_annual_limit.py +++ b/tests/test_annual_limit.py @@ -127,6 +127,31 @@ def test_clear_notification_counts(mock_annual_limit_client, mock_notification_c assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0 +@pytest.mark.parametrize( + "service_ids", + [ + [ + str(uuid.uuid4()), + str(uuid.uuid4()), + str(uuid.uuid4()), + str(uuid.uuid4()), + ] + ], +) +def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notification_count_types, service_ids): + for service_id in service_ids: + for field in mock_notification_count_types: + mock_annual_limit_client.increment_notification_count(service_id, field) + assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4 + + mock_annual_limit_client.reset_all_notification_counts() + + for service_id in service_ids: + assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4 + for field in mock_notification_count_types: + assert mock_annual_limit_client.get_notification_count(service_id, field) == 0 + + def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id): mock_annual_limit_client.set_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT, datetime.utcnow()) result = mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py index 4a9a5a62..f302a289 100644 --- a/tests/test_redis_client.py +++ b/tests/test_redis_client.py @@ -184,6 +184,51 @@ def test_set_hash_value(mocked_redis_client): mocked_redis_client.redis_store.hset.assert_called_with(key, field, value) +@pytest.mark.parametrize( + "hash, updates, expected", + [ + ( + { + "key1": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key2": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key3": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + }, + { + "field1": "value2", + "field2": 3, + "field3": "value4".encode("utf-8"), + }, + { + b"field1": b"value2", + b"field2": b"3", + b"field3": b"value4", + }, + ) + ], +) +def test_bulk_set_hash_fields(better_mocked_redis_client, hash, updates, expected): + for key, fields in hash.items(): + for field, value in fields.items(): + better_mocked_redis_client.set_hash_value(key, field, value) + + better_mocked_redis_client.bulk_set_hash_fields("key*", mapping=updates) + + for key, _ in hash.items(): + assert better_mocked_redis_client.redis_store.hgetall(key) == expected + + def test_decrement_hash_value_should_decrement_value_by_one_for_key(mocked_redis_client): key = "12345" value = "template-1111" From b9ad2e1553b5fe9c86a4ad08850d7cc8c6ba9810 Mon Sep 17 00:00:00 2001 From: wbanks Date: Tue, 5 Nov 2024 17:13:20 -0500 Subject: [PATCH 2/8] Rework resetting notification counts - create_nightly_notification_status_for_day processes services in chunks, so we'll also atomically reset cache values in chunks too - Added doc string explaining hash structure in the annual limit client --- .../clients/redis/annual_limit.py | 65 +++++++++++++++---- .../clients/redis/redis_client.py | 56 +++++++++++++++- 2 files changed, 106 insertions(+), 15 deletions(-) diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 
58f2f17c..2f6b25d8 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -1,4 +1,27 @@ -"""This module stores daily notification counts and annual limit statuses for a service in Redis.""" +""" +This module stores daily notification counts and annual limit statuses for a service in Redis using a hash structure: + + +annual-limit: { + {service_id}: { + notifications: { + sms_delivered: int, + email_delivered: int, + sms_failed: int, + email_failed: int + }, + status: { + near_sms_limit: Datetime, + near_email_limit: Datetime, + over_sms_limit: Datetime, + over_email_limit: Datetime + seeded_at: Datetime + } + } +} + + +""" from datetime import datetime @@ -9,13 +32,18 @@ SMS_FAILED = "sms_failed" EMAIL_FAILED = "email_failed" +NOTIFICATIONS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED] + NEAR_SMS_LIMIT = "near_sms_limit" NEAR_EMAIL_LIMIT = "near_email_limit" OVER_SMS_LIMIT = "over_sms_limit" OVER_EMAIL_LIMIT = "over_email_limit" +SEEDED_AT = "seeded_at" + +STATUSES = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT] -def notifications_key(service_id): +def annual_limit_notifications_key(service_id): """ Generates the Redis hash key for storing daily metrics of a service. """ @@ -44,35 +72,48 @@ def init_app(self, app, *args, **kwargs): pass def increment_notification_count(self, service_id: str, field: str): - self._redis_client.increment_hash_value(notifications_key(service_id), field) + self._redis_client.increment_hash_value(annual_limit_notifications_key(service_id), field) def get_notification_count(self, service_id: str, field: str): """ Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.) """ - return int(self._redis_client.get_hash_field(notifications_key(service_id), field)) + return int(self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field)) def get_all_notification_counts(self, service_id: str): """ Retrieves all daily notification metrics for a service. """ - return decode_byte_dict(self._redis_client.get_all_from_hash(notifications_key(service_id))) + return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id))) - def reset_all_notification_counts(self): + def reset_all_notification_counts(self, service_ids=None): """ - Resets all daily notification metrics for all services. Uses non-blocking scan_iter method to avoid locking the Redis server. + Resets all daily notification metrics. 
+ :param: service_ids: list of service_ids to reset, if None, resets all services """ - pattern = notifications_key("*") # All notification keys regardless of service_id - - self._redis_client.bulk_set_hash_fields( - pattern, ({"email_delivered": 0, "email_failed": 0, "sms_delivered": 0, "sms_failed": 0}) + hashes = ( + annual_limit_notifications_key("*") + if not service_ids + else [annual_limit_notifications_key(service_id) for service_id in service_ids] ) + self._redis_client.delete_hash_fields(hashes=hashes) + + def was_seeded_today(self, service_id): + last_seeded_time = self.get_seeded_at(service_id) + return last_seeded_time == datetime.utcnow().strftime("%Y-%m-%d") if last_seeded_time else False + + def get_seeded_at(self, service_id: str): + return self._redis_client.get_hash_field(annual_limit_status_key(service_id), SEEDED_AT).decode("utf-8") + + def set_seeded_at(self, service_id): + self._redis_client.set_hash_value(annual_limit_status_key(service_id), SEEDED_AT, datetime.utcnow().strftime("%Y-%m-%d")) + def clear_notification_counts(self, service_id: str): """ Clears all daily notification metrics for a service. """ - self._redis_client.expire(notifications_key(service_id), -1) + self._redis_client.expire(annual_limit_notifications_key(service_id), -1) def set_annual_limit_status(self, service_id: str, field: str, value: datetime): """ diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index 259f084f..1c5f7fb5 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -81,7 +81,57 @@ def delete_cache_keys_by_pattern(self, pattern): return self.scripts["delete-keys-by-pattern"](args=[pattern]) return 0 - def bulk_set_hash_fields(self, pattern, mapping, raise_exception=False): + def delete_hash_fields(self, hashes: (str | list), fields: list = None, raise_exception=False): + """Deletes fields from the specified hashes. if fields is `None`, then all fields from the hashes are deleted, deleting the hash entirely. + + Args: + hashes (str|list): The hash pattern or list of hash keys to delete fields from. + fields (list): A list of fields to delete from the hashes. If `None`, then all fields are deleted. + + Returns: + _type_: _description_ + """ + if self.active: + try: + hashes = [prepare_value(h) for h in hashes] if isinstance(hashes, list) else prepare_value(hashes) + # When fields are passed in, use the list as is + # When hashes is a list, and no fields are passed in, fetch the fields from the first hash in the list + # otherwise we know we're going scan iterate over a pattern so we'll fetch the fields on the first pass in the loop below + fields = ( + [prepare_value(f) for f in fields] + if fields is not None + else self.redis_store.hkeys(hashes[0]) + if isinstance(hashes, list) + else None + ) + # Use a pipeline to atomically delete fields from each hash. + pipe = self.redis_store.pipeline() + # if hashes is not a list, we're scan iterating over keys matching a pattern + for key in hashes if isinstance(hashes, list) else self.redis_store.scan_iter(hashes): + if not fields: + fields = self.redis_store.hkeys(key) + key = prepare_value(key) + pipe.hdel(key, *fields) + result = pipe.execute() + return result + except Exception as e: + self.__handle_exception(e, raise_exception, "expire_hash_fields", hashes) + + def set_hash_fields_by_pattern_or_keys(self, mapping, keys: str | list = None, raise_exception=False): + """ + Bulk set hash fields. 
+ :param pattern: the pattern to match keys or a list of keys to set + :param mappting: the mapping of fields to set + :param raise_exception: True if we should allow the exception to bubble up + """ + if self.active: + try: + for key in self.redis_store.hscan_iter(keys): + self.redis_store.hmset(key, mapping) + except Exception as e: + self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", keys) + + def set_hash_fields_by_keys(self, keys, mapping, raise_exception=False): """ Bulk set hash fields. :param pattern: the pattern to match keys @@ -90,10 +140,10 @@ def bulk_set_hash_fields(self, pattern, mapping, raise_exception=False): """ if self.active: try: - for key in self.redis_store.scan_iter(pattern): + for key in self.redis_store.scan_iter(keys): self.redis_store.hmset(key, mapping) except Exception as e: - self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", pattern) + self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", keys) def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False): """ From 138c7fb49f832db00642981d40777c9e8a7c9dea Mon Sep 17 00:00:00 2001 From: wbanks Date: Tue, 5 Nov 2024 17:17:48 -0500 Subject: [PATCH 3/8] Cleanup --- .../clients/redis/redis_client.py | 32 +++---------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index 1c5f7fb5..544f5801 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -113,37 +113,13 @@ def delete_hash_fields(self, hashes: (str | list), fields: list = None, raise_ex key = prepare_value(key) pipe.hdel(key, *fields) result = pipe.execute() + # TODO: May need to double check that the pipeline result count matches the number of hashes deleted + # and retry any failures return result except Exception as e: self.__handle_exception(e, raise_exception, "expire_hash_fields", hashes) - - def set_hash_fields_by_pattern_or_keys(self, mapping, keys: str | list = None, raise_exception=False): - """ - Bulk set hash fields. - :param pattern: the pattern to match keys or a list of keys to set - :param mappting: the mapping of fields to set - :param raise_exception: True if we should allow the exception to bubble up - """ - if self.active: - try: - for key in self.redis_store.hscan_iter(keys): - self.redis_store.hmset(key, mapping) - except Exception as e: - self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", keys) - - def set_hash_fields_by_keys(self, keys, mapping, raise_exception=False): - """ - Bulk set hash fields. 
- :param pattern: the pattern to match keys - :param mappting: the mapping of fields to set - :param raise_exception: True if we should allow the exception to bubble up - """ - if self.active: - try: - for key in self.redis_store.scan_iter(keys): - self.redis_store.hmset(key, mapping) - except Exception as e: - self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", keys) + else: + return False def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False): """ From a7fb0fa8dc54c96e73ca2e114805856b73ae67c5 Mon Sep 17 00:00:00 2001 From: wbanks Date: Wed, 6 Nov 2024 12:23:28 -0500 Subject: [PATCH 4/8] Add tests for delete_hash_fields --- .../clients/redis/redis_client.py | 18 ++ tests/test_annual_limit.py | 4 +- tests/test_redis_client.py | 248 +++++++++++------- 3 files changed, 169 insertions(+), 101 deletions(-) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index 544f5801..c22d28e7 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -81,6 +81,7 @@ def delete_cache_keys_by_pattern(self, pattern): return self.scripts["delete-keys-by-pattern"](args=[pattern]) return 0 + # TODO: Refactor and simplify this to use HEXPIRE when we upgrade Redis to 7.4.0 def delete_hash_fields(self, hashes: (str | list), fields: list = None, raise_exception=False): """Deletes fields from the specified hashes. if fields is `None`, then all fields from the hashes are deleted, deleting the hash entirely. @@ -121,6 +122,23 @@ def delete_hash_fields(self, hashes: (str | list), fields: list = None, raise_ex else: return False + def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception=False): + """ + Bulk set hash fields. + :param pattern: the pattern to match keys + :param mappting: the mapping of fields to set + :param raise_exception: True if we should allow the exception to bubble up + """ + if self.active: + try: + if pattern: + for key in self.redis_store.scan_iter(pattern): + self.redis_store.hmset(key, mapping) + if key: + self.redis_store.hmset(key, mapping) + except Exception as e: + self.__handle_exception(e, raise_exception, "bulk_set_hash_fields", pattern) + def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False): """ Rate limiting. 
diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py index 62033ef9..9720f590 100644 --- a/tests/test_annual_limit.py +++ b/tests/test_annual_limit.py @@ -15,8 +15,8 @@ SMS_DELIVERED, SMS_FAILED, RedisAnnualLimit, + annual_limit_notifications_key, annual_limit_status_key, - notifications_key, ) from notifications_utils.clients.redis.redis_client import RedisClient @@ -79,7 +79,7 @@ def mocked_service_id(): def test_notifications_key(mocked_service_id): expected_key = f"annual-limit:{mocked_service_id}:notifications" - assert notifications_key(mocked_service_id) == expected_key + assert annual_limit_notifications_key(mocked_service_id) == expected_key def test_annual_limits_key(mocked_service_id): diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py index f302a289..a334cedb 100644 --- a/tests/test_redis_client.py +++ b/tests/test_redis_client.py @@ -33,6 +33,27 @@ def better_mocked_redis_client(app): return redis_client +@pytest.fixture(scope="function") +def mocked_hash_structure(): + return { + "key1": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key2": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key3": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + } + + def build_redis_client(app, mocked_redis_pipeline, mocker): redis_client = RedisClient() redis_client.init_app(app) @@ -168,105 +189,6 @@ def test_should_build_rate_limit_cache_key(sample_service): assert rate_limit_cache_key(sample_service.id, "TEST") == "{}-TEST".format(sample_service.id) -def test_get_hash_field(mocked_redis_client): - key = "12345" - field = "template-1111" - mocked_redis_client.redis_store.hget = Mock(return_value=b"8") - assert mocked_redis_client.get_hash_field(key, field) == b"8" - mocked_redis_client.redis_store.hget.assert_called_with(key, field) - - -def test_set_hash_value(mocked_redis_client): - key = "12345" - field = "template-1111" - value = 8 - mocked_redis_client.set_hash_value(key, field, value) - mocked_redis_client.redis_store.hset.assert_called_with(key, field, value) - - -@pytest.mark.parametrize( - "hash, updates, expected", - [ - ( - { - "key1": { - "field1": "value1", - "field2": 2, - "field3": "value3".encode("utf-8"), - }, - "key2": { - "field1": "value1", - "field2": 2, - "field3": "value3".encode("utf-8"), - }, - "key3": { - "field1": "value1", - "field2": 2, - "field3": "value3".encode("utf-8"), - }, - }, - { - "field1": "value2", - "field2": 3, - "field3": "value4".encode("utf-8"), - }, - { - b"field1": b"value2", - b"field2": b"3", - b"field3": b"value4", - }, - ) - ], -) -def test_bulk_set_hash_fields(better_mocked_redis_client, hash, updates, expected): - for key, fields in hash.items(): - for field, value in fields.items(): - better_mocked_redis_client.set_hash_value(key, field, value) - - better_mocked_redis_client.bulk_set_hash_fields("key*", mapping=updates) - - for key, _ in hash.items(): - assert better_mocked_redis_client.redis_store.hgetall(key) == expected - - -def test_decrement_hash_value_should_decrement_value_by_one_for_key(mocked_redis_client): - key = "12345" - value = "template-1111" - - mocked_redis_client.decrement_hash_value(key, value, -1) - mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, -1) - - -def test_incr_hash_value_should_increment_value_by_one_for_key(mocked_redis_client): - key = "12345" - value = "template-1111" - - mocked_redis_client.increment_hash_value(key, value) - 
mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, 1) - - -def test_get_all_from_hash_returns_hash_for_key(mocked_redis_client): - key = "12345" - assert mocked_redis_client.get_all_from_hash(key) == {b"template-1111": b"8", b"template-2222": b"8"} - mocked_redis_client.redis_store.hgetall.assert_called_with(key) - - -def test_set_hash_and_expire(mocked_redis_client): - key = "hash-key" - values = {"key": 10} - mocked_redis_client.set_hash_and_expire(key, values, 1) - mocked_redis_client.redis_store.hmset.assert_called_with(key, values) - mocked_redis_client.redis_store.expire.assert_called_with(key, 1) - - -def test_set_hash_and_expire_converts_values_to_valid_types(mocked_redis_client): - key = "hash-key" - values = {uuid.UUID(int=0): 10} - mocked_redis_client.set_hash_and_expire(key, values, 1) - mocked_redis_client.redis_store.hmset.assert_called_with(key, {"00000000-0000-0000-0000-000000000000": 10}) - mocked_redis_client.redis_store.expire.assert_called_with(key, 1) - - @freeze_time("2001-01-01 12:00:00.000000") def test_should_add_correct_calls_to_the_pipe(mocked_redis_client, mocked_redis_pipeline): mocked_redis_client.exceeded_rate_limit("key", 100, 100) @@ -377,3 +299,131 @@ def test_get_length_of_sorted_set_returns_none_if_not_active(self, better_mocked better_mocked_redis_client.active = False ret = better_mocked_redis_client.get_length_of_sorted_set("cache_key", min_score=0, max_score=100) assert ret == 0 + + +class TestRedisHashes: + @pytest.mark.parametrize( + "hash_key, fields_to_delete, expected_deleted, check_if_no_longer_exists", + [ + ("test:hash:key1", ["field1", "field2"], 2, False), # Delete specific fields in a hash + ("test:hash:key1", None, 3, True), # Delete All fields in a specific hash + ("test:hash:*", ["field1", "field2"], 6, False), # Delete specific fields in a group of hashes + ("test:hash:*", None, 9, True), # Delete All fields in a group of hashes + ], + ) + def test_delete_hash_fields( + self, + better_mocked_redis_client, + hash_key, + fields_to_delete, + expected_deleted, + check_if_no_longer_exists, + mocked_hash_structure, + ): + # set up the hashes to be deleted + for key, fields in mocked_hash_structure.items(): + better_mocked_redis_client.bulk_set_hash_fields(key=f"test:hash:{key}", mapping=fields) + + num_deleted = better_mocked_redis_client.delete_hash_fields(hashes=hash_key, fields=fields_to_delete) + + # Deleting all hash fields by pattern + if check_if_no_longer_exists and "*" in hash_key: + for key in mocked_hash_structure.keys(): + assert better_mocked_redis_client.redis_store.exists(f"test:hash:{key}") == 0 + # Deleting a specific hash + elif check_if_no_longer_exists: + assert better_mocked_redis_client.redis_store.exists(f"test:hash:{hash_key}") == 0 + + # Make sure we've deleted the correct number of fields + assert sum(num_deleted) == expected_deleted + + def test_get_hash_field(self, mocked_redis_client): + key = "12345" + field = "template-1111" + mocked_redis_client.redis_store.hget = Mock(return_value=b"8") + assert mocked_redis_client.get_hash_field(key, field) == b"8" + mocked_redis_client.redis_store.hget.assert_called_with(key, field) + + def test_set_hash_value(self, mocked_redis_client): + key = "12345" + field = "template-1111" + value = 8 + mocked_redis_client.set_hash_value(key, field, value) + mocked_redis_client.redis_store.hset.assert_called_with(key, field, value) + + @pytest.mark.parametrize( + "hash, updates, expected", + [ + ( + { + "key1": { + "field1": "value1", + "field2": 2, + "field3": 
"value3".encode("utf-8"), + }, + "key2": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + "key3": { + "field1": "value1", + "field2": 2, + "field3": "value3".encode("utf-8"), + }, + }, + { + "field1": "value2", + "field2": 3, + "field3": "value4".encode("utf-8"), + }, + { + b"field1": b"value2", + b"field2": b"3", + b"field3": b"value4", + }, + ) + ], + ) + def test_bulk_set_hash_fields(self, better_mocked_redis_client, hash, updates, expected): + for key, fields in hash.items(): + for field, value in fields.items(): + better_mocked_redis_client.set_hash_value(key, field, value) + + better_mocked_redis_client.bulk_set_hash_fields(pattern="key*", mapping=updates) + + for key, _ in hash.items(): + assert better_mocked_redis_client.redis_store.hgetall(key) == expected + + def test_decrement_hash_value_should_decrement_value_by_one_for_key(self, mocked_redis_client): + key = "12345" + value = "template-1111" + + mocked_redis_client.decrement_hash_value(key, value, -1) + mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, -1) + + def test_incr_hash_value_should_increment_value_by_one_for_key(self, mocked_redis_client): + key = "12345" + value = "template-1111" + + mocked_redis_client.increment_hash_value(key, value) + mocked_redis_client.redis_store.hincrby.assert_called_with(key, value, 1) + + def test_get_all_from_hash_returns_hash_for_key(self, mocked_redis_client): + key = "12345" + assert mocked_redis_client.get_all_from_hash(key) == {b"template-1111": b"8", b"template-2222": b"8"} + mocked_redis_client.redis_store.hgetall.assert_called_with(key) + + def test_set_hash_and_expire(self, mocked_redis_client): + key = "hash-key" + values = {"key": 10} + mocked_redis_client.set_hash_and_expire(key, values, 1) + mocked_redis_client.redis_store.hmset.assert_called_with(key, values) + mocked_redis_client.redis_store.expire.assert_called_with(key, 1) + + def test_set_hash_and_expire_converts_values_to_valid_types(self, mocked_redis_client): + key = "hash-key" + values = {uuid.UUID(int=0): 10} + mocked_redis_client.set_hash_and_expire(key, values, 1) + mocked_redis_client.redis_store.hmset.assert_called_with(key, {"00000000-0000-0000-0000-000000000000": 10}) + mocked_redis_client.redis_store.expire.assert_called_with(key, 1) From 3c896e96f123df38155296bc29bcdb90f06ef07c Mon Sep 17 00:00:00 2001 From: wbanks Date: Wed, 6 Nov 2024 15:21:43 -0500 Subject: [PATCH 5/8] Fix implicit returns --- notifications_utils/clients/redis/redis_client.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index c22d28e7..6ffe8888 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -119,8 +119,7 @@ def delete_hash_fields(self, hashes: (str | list), fields: list = None, raise_ex return result except Exception as e: self.__handle_exception(e, raise_exception, "expire_hash_fields", hashes) - else: - return False + return False def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception=False): """ @@ -135,9 +134,10 @@ def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception= for key in self.redis_store.scan_iter(pattern): self.redis_store.hmset(key, mapping) if key: - self.redis_store.hmset(key, mapping) + return self.redis_store.hmset(key, mapping) except Exception as e: self.__handle_exception(e, raise_exception, 
"bulk_set_hash_fields", pattern) + return False def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False): """ @@ -293,10 +293,12 @@ def set_hash_value(self, key, field, value, raise_exception=False): if self.active: try: - self.redis_store.hset(key, field, value) + return self.redis_store.hset(key, field, value) except Exception as e: self.__handle_exception(e, raise_exception, "set_hash_value", key) + return None + def decrement_hash_value(self, key, value, raise_exception=False): return self.increment_hash_value(key, value, raise_exception, incr_by=-1) @@ -331,6 +333,8 @@ def get_all_from_hash(self, key, raise_exception=False): except Exception as e: self.__handle_exception(e, raise_exception, "get_all_from_hash", key) + return None + def set_hash_and_expire(self, key, values, expire_in_seconds, raise_exception=False): key = prepare_value(key) values = {prepare_value(k): prepare_value(v) for k, v in values.items()} @@ -341,6 +345,8 @@ def set_hash_and_expire(self, key, values, expire_in_seconds, raise_exception=Fa except Exception as e: self.__handle_exception(e, raise_exception, "set_hash_and_expire", key) + return None + def expire(self, key, expire_in_seconds, raise_exception=False): key = prepare_value(key) if self.active: From e05b0ac93e74b297538d62f403732148fe5aee5c Mon Sep 17 00:00:00 2001 From: wbanks Date: Thu, 7 Nov 2024 10:41:49 -0500 Subject: [PATCH 6/8] Fix implicit optional, fix tests --- notifications_utils/clients/redis/redis_client.py | 4 ++-- tests/test_annual_limit.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index 6ffe8888..3f4831ab 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -1,7 +1,7 @@ import numbers import uuid from time import time -from typing import Any, Dict +from typing import Any, Dict, Optional from flask import current_app from flask_redis import FlaskRedis @@ -82,7 +82,7 @@ def delete_cache_keys_by_pattern(self, pattern): return 0 # TODO: Refactor and simplify this to use HEXPIRE when we upgrade Redis to 7.4.0 - def delete_hash_fields(self, hashes: (str | list), fields: list = None, raise_exception=False): + def delete_hash_fields(self, hashes: (str | list), fields: Optional[list] = None, raise_exception=False): """Deletes fields from the specified hashes. if fields is `None`, then all fields from the hashes are deleted, deleting the hash entirely. 
Args: diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py index 9720f590..b11e6336 100644 --- a/tests/test_annual_limit.py +++ b/tests/test_annual_limit.py @@ -147,9 +147,7 @@ def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notificat mock_annual_limit_client.reset_all_notification_counts() for service_id in service_ids: - assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4 - for field in mock_notification_count_types: - assert mock_annual_limit_client.get_notification_count(service_id, field) == 0 + assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 0 def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id): From 089738dc75494e5302c81d06775c56f66a47ca28 Mon Sep 17 00:00:00 2001 From: wbanks Date: Thu, 7 Nov 2024 15:31:59 -0500 Subject: [PATCH 7/8] Allow dict values to be casted during decode --- notifications_utils/clients/redis/annual_limit.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 2f6b25d8..732e62e0 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -57,11 +57,14 @@ def annual_limit_status_key(service_id): return f"annual-limit:{service_id}:status" -def decode_byte_dict(dict: dict): +def decode_byte_dict(dict: dict, value_type=str): """ Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings. """ - return {key.decode("utf-8"): value.decode("utf-8") for key, value in dict.items()} + # Check if expected_value_type is one of the allowed types + if value_type not in {int, float, str}: + raise ValueError("expected_value_type must be int, float, or str") + return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items()} class RedisAnnualLimit: @@ -84,7 +87,7 @@ def get_all_notification_counts(self, service_id: str): """ Retrieves all daily notification metrics for a service. """ - return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id))) + return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int) def reset_all_notification_counts(self, service_ids=None): """ From c73808ef4e27eb39532bb58a0d814a80e441ebef Mon Sep 17 00:00:00 2001 From: wbanks Date: Thu, 7 Nov 2024 16:42:10 -0500 Subject: [PATCH 8/8] Add method to populate all notification count fields during seeding --- notifications_utils/clients/redis/annual_limit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 732e62e0..20bab04e 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -102,6 +102,9 @@ def reset_all_notification_counts(self, service_ids=None): self._redis_client.delete_hash_fields(hashes=hashes) + def seed_annual_limit_notifications(self, service_id: str, mapping: dict): + self._redis_client.bulk_set_hash_fields(key=annual_limit_notifications_key(service_id), mapping=mapping) + def was_seeded_today(self, service_id): last_seeded_time = self.get_seeded_at(service_id) return last_seeded_time == datetime.utcnow().strftime("%Y-%m-%d") if last_seeded_time else False
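
A minimal, self-contained sketch of the two helpers touched above — the renamed annual_limit_notifications_key (PATCH 2/8) and the value-casting decode_byte_dict (PATCH 7/8) — copied as they appear in the diffs and exercised against a hand-built byte dictionary that stands in for what redis-py's hgetall would return. The sample service id and counts are illustrative only.

    # Reproduces the helpers from the patches above so the casting behaviour can be
    # checked without a Redis server. sample_hash mimics the byte strings redis-py
    # returns from HGETALL on an annual-limit notifications hash.


    def annual_limit_notifications_key(service_id):
        """Generates the Redis hash key for storing daily metrics of a service."""
        return f"annual-limit:{service_id}:notifications"


    def decode_byte_dict(dict: dict, value_type=str):
        """Decodes redis-py byte keys/values to UTF-8 strings, casting values to value_type."""
        if value_type not in {int, float, str}:
            raise ValueError("expected_value_type must be int, float, or str")
        return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items()}


    if __name__ == "__main__":
        sample_hash = {b"sms_delivered": b"3", b"email_delivered": b"5", b"sms_failed": b"0", b"email_failed": b"1"}
        print(annual_limit_notifications_key("1234"))   # annual-limit:1234:notifications
        print(decode_byte_dict(sample_hash, int))        # {'sms_delivered': 3, 'email_delivered': 5, ...}

Because get_all_notification_counts now passes int as the value_type (PATCH 7/8), callers get integer counts back rather than the stringified values returned before this series.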
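
A usage sketch of how an application task might combine the new client methods (reset_all_notification_counts, was_seeded_today, seed_annual_limit_notifications, set_seeded_at) during the chunked nightly job described in PATCH 2/8's commit message. This is not code from the series: the chunk size, the nightly_reset_and_seed helper, counts_by_service, and the way annual_limit_client is constructed are assumptions; the seed mapping simply reuses the four notification fields defined in annual_limit.py.

    from notifications_utils.clients.redis.annual_limit import (
        EMAIL_DELIVERED,
        EMAIL_FAILED,
        SMS_DELIVERED,
        SMS_FAILED,
    )


    def nightly_reset_and_seed(annual_limit_client, service_ids, counts_by_service, chunk_size=500):
        """Hypothetical nightly task: reset daily counts in chunks, then re-seed from the database.

        annual_limit_client is assumed to be an already-initialised RedisAnnualLimit;
        counts_by_service maps service_id -> {field: count} as computed elsewhere.
        """
        for start in range(0, len(service_ids), chunk_size):
            chunk = service_ids[start : start + chunk_size]
            # Deletes the whole notifications hash for each service in the chunk
            # (delete_hash_fields with fields=None removes every field).
            annual_limit_client.reset_all_notification_counts(chunk)

            for service_id in chunk:
                if annual_limit_client.was_seeded_today(service_id):
                    continue
                mapping = counts_by_service.get(
                    service_id,
                    {SMS_DELIVERED: 0, EMAIL_DELIVERED: 0, SMS_FAILED: 0, EMAIL_FAILED: 0},
                )
                annual_limit_client.seed_annual_limit_notifications(service_id, mapping)
                annual_limit_client.set_seeded_at(service_id)

One caveat for callers: as written in PATCH 2/8, get_seeded_at calls .decode("utf-8") directly on the raw hash value, so it appears it will raise if the seeded_at field has never been set; a guard around the first seeding of a service (or a None check inside get_seeded_at) may be needed.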