Skip to content

Commit

Permalink
Fix bug in delete_hash_fields (#337)
Browse files Browse the repository at this point in the history
- delete_hash_fields was occasionally missing fields when used to clear
annual limit notification counts
- Refactored annual limit client methods that returned numeric data to
  return 0 instead of None when no data was found in the queried hash
  • Loading branch information
whabanks authored Nov 21, 2024
1 parent b344e5a commit 25ef1ae
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .github/actions/waffles/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
docopt==0.6.2
Flask==2.3.3
markupsafe==2.1.5
git+https://github.com/cds-snc/notifier-utils.git@52.3.9#egg=notifications-utils
git+https://github.com/cds-snc/notifier-utils.git@52.4.0#egg=notifications-utils
64 changes: 54 additions & 10 deletions notifications_utils/clients/redis/annual_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
SMS_FAILED = "sms_failed"
EMAIL_FAILED = "email_failed"

NOTIFICATIONS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED]
NOTIFICATION_FIELDS = [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED]

NEAR_SMS_LIMIT = "near_sms_limit"
NEAR_EMAIL_LIMIT = "near_email_limit"
OVER_SMS_LIMIT = "over_sms_limit"
OVER_EMAIL_LIMIT = "over_email_limit"
SEEDED_AT = "seeded_at"

STATUSES = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT]
STATUS_FIELDS = [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT]


def annual_limit_notifications_key(service_id):
Expand All @@ -57,16 +57,42 @@ def annual_limit_status_key(service_id):
return f"annual-limit:{service_id}:status"


def decode_byte_dict(byte_dict: dict, value_type=str):
def prepare_byte_dict(byte_dict: dict, value_type=str, required_keys=None):
"""
Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings.
"""
# Check if expected_value_type is one of the allowed types
if value_type not in {int, float, str}:
raise ValueError("expected_value_type must be int, float, or str")
if byte_dict is None or not byte_dict.items():
return None
return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()}

decoded_dict = (
{key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()} if byte_dict else {}
)

if required_keys:
for key in required_keys:
default_value = 0 if value_type in {int, float} else None
decoded_dict.setdefault(key, default_value)
return decoded_dict


def init_missing_keys(
required_keys: list,
value_type=str,
incomplete_dict: dict = {},
):
"""Ensures that all expected keys are present in dicts returned from this module. Initializes empty values to defaults if not.
Args:
incomplete_dict (dict): A dictionary to check for required keys.
required_keys (list): The keys that must be present in the dictionary.
value_type (_type_, optional): The datatype of the values in the dict. Defaults to str.
Raises:
ValueError: If the value_type is not int, float, or str.
"""
if value_type not in {int, float, str}:
raise ValueError("expected_value_type must be int, float, or str")


class RedisAnnualLimit:
Expand All @@ -91,13 +117,15 @@ def get_notification_count(self, service_id: str, field: str):
Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.)
"""
count = self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field)
return count and int(count.decode("utf-8"))
return 0 if not count else int(count.decode("utf-8"))

def get_all_notification_counts(self, service_id: str):
"""
Retrieves all daily notification metrics for a service.
"""
return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int)
return prepare_byte_dict(
self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int, NOTIFICATION_FIELDS
)

def reset_all_notification_counts(self, service_ids=None):
"""Resets all daily notification metrics.
Expand All @@ -112,7 +140,7 @@ def reset_all_notification_counts(self, service_ids=None):
else [annual_limit_notifications_key(service_id) for service_id in service_ids]
)

self._redis_client.delete_hash_fields(hashes=hashes)
self._redis_client.delete_hash_fields(hashes=hashes, fields=NOTIFICATION_FIELDS)

def seed_annual_limit_notifications(self, service_id: str, mapping: dict):
"""Seeds annual limit notifications for a service.
Expand Down Expand Up @@ -187,7 +215,7 @@ def get_all_annual_limit_statuses(self, service_id: str):
Returns:
dict | None: A dictionary of annual limit statuses or None if no statuses are found
"""
return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id)))
return prepare_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id)), str, STATUS_FIELDS)

def clear_annual_limit_statuses(self, service_id: str):
self._redis_client.expire(f"{annual_limit_status_key(service_id)}", -1)
Expand Down Expand Up @@ -237,3 +265,19 @@ def check_has_over_limit_been_sent(self, service_id: str, message_type: str):
"""
field_to_fetch = OVER_SMS_LIMIT if message_type == "sms" else OVER_EMAIL_LIMIT
return self.get_annual_limit_status(service_id, field_to_fetch)

def delete_all_annual_limit_hashes(self, service_ids=None):
"""
THIS SHOULD NOT BE CALLED IN CODE. This is a helper method for testing purposes only.
Clears all annual limit hashes in Redis
Args:
service_ids (Optional): A list of service_ids to clear annual limit hashes for. Clears all services if None.
"""
if not service_ids:
self._redis_client.delete_cache_keys_by_pattern(annual_limit_notifications_key("*"))
self._redis_client.delete_cache_keys_by_pattern(annual_limit_status_key("*"))
else:
for service_id in service_ids:
self._redis_client.delete(annual_limit_notifications_key(service_id))
self._redis_client.delete(annual_limit_status_key(service_id))
12 changes: 3 additions & 9 deletions notifications_utils/clients/redis/redis_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numbers
import uuid
from time import time
from typing import Any, Dict, Optional
from typing import Any, Dict

from flask import current_app
from flask_redis import FlaskRedis
Expand Down Expand Up @@ -82,7 +82,7 @@ def delete_cache_keys_by_pattern(self, pattern):
return 0

# TODO: Refactor and simplify this to use HEXPIRE when we upgrade Redis to 7.4.0
def delete_hash_fields(self, hashes: (str | list), fields: Optional[list] = None, raise_exception=False):
def delete_hash_fields(self, hashes: (str | list), fields: list, raise_exception=False):
"""Deletes fields from the specified hashes. if fields is `None`, then all fields from the hashes are deleted, deleting the hash entirely.
Args:
Expand All @@ -98,13 +98,7 @@ def delete_hash_fields(self, hashes: (str | list), fields: Optional[list] = None
# When fields are passed in, use the list as is
# When hashes is a list, and no fields are passed in, fetch the fields from the first hash in the list
# otherwise we know we're going scan iterate over a pattern so we'll fetch the fields on the first pass in the loop below
fields = (
[prepare_value(f) for f in fields]
if fields is not None
else self.redis_store.hkeys(hashes[0])
if isinstance(hashes, list)
else None
)
fields = [prepare_value(f) for f in fields]
# Use a pipeline to atomically delete fields from each hash.
pipe = self.redis_store.pipeline()
# if hashes is not a list, we're scan iterating over keys matching a pattern
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "notifications-utils"
version = "52.3.9"
version = "52.4.0"
description = "Shared python code for Notification - Provides logging utils etc."
authors = ["Canadian Digital Service"]
license = "MIT license"
Expand Down
101 changes: 58 additions & 43 deletions tests/test_annual_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,19 @@
EMAIL_FAILED,
NEAR_EMAIL_LIMIT,
NEAR_SMS_LIMIT,
NOTIFICATION_FIELDS,
OVER_EMAIL_LIMIT,
OVER_SMS_LIMIT,
SMS_DELIVERED,
SMS_FAILED,
STATUS_FIELDS,
RedisAnnualLimit,
annual_limit_notifications_key,
annual_limit_status_key,
)
from notifications_utils.clients.redis.redis_client import RedisClient


@pytest.fixture(scope="function")
def mock_notification_count_types():
return [SMS_DELIVERED, EMAIL_DELIVERED, SMS_FAILED, EMAIL_FAILED]


@pytest.fixture(scope="function")
def mock_annual_limit_statuses():
return [NEAR_SMS_LIMIT, NEAR_EMAIL_LIMIT, OVER_SMS_LIMIT, OVER_EMAIL_LIMIT]


@pytest.fixture(scope="function")
def mocked_redis_pipeline():
return Mock()
Expand Down Expand Up @@ -114,25 +106,32 @@ def test_get_notification_count(mock_annual_limit_client, mocked_service_id):


def test_get_notification_count_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) is None
assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 0


def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
def test_get_all_notification_counts(mock_annual_limit_client, mocked_service_id):
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4
counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id)
assert len(counts) == 4
assert all(isinstance(value, int) for value in counts.values())


def test_get_all_notification_counts_returns_none_if_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None
notification_counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id)
assert set(notification_counts.keys()) == set(NOTIFICATION_FIELDS)
assert all(value == 0 for value in notification_counts.values())


def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
def test_clear_notification_counts(mock_annual_limit_client, mocked_service_id):
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4

mock_annual_limit_client.clear_notification_counts(mocked_service_id)
assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None
counts = mock_annual_limit_client.get_all_notification_counts(mocked_service_id)
assert set(counts.keys()) == set(NOTIFICATION_FIELDS)
assert all(value == 0 for value in counts.values())


@pytest.mark.parametrize(
Expand All @@ -146,16 +145,20 @@ def test_clear_notification_counts(mock_annual_limit_client, mock_notification_c
]
],
)
def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notification_count_types, service_ids):
def test_bulk_reset_notification_counts(mock_annual_limit_client, service_ids):
for service_id in service_ids:
for field in mock_notification_count_types:
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 4

mock_annual_limit_client.reset_all_notification_counts()
counts = mock_annual_limit_client.get_all_notification_counts(service_id)
assert set(counts.keys()) == set(NOTIFICATION_FIELDS)
assert all(value > 0 for value in counts.values())

mock_annual_limit_client.reset_all_notification_counts()
for service_id in service_ids:
assert mock_annual_limit_client.get_all_notification_counts(service_id) is None
counts = mock_annual_limit_client.get_all_notification_counts(service_id)
assert set(counts.keys()) == set(NOTIFICATION_FIELDS)
assert all(value == 0 for value in counts.values())


def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id):
Expand All @@ -177,23 +180,35 @@ def test_get_annual_limit_status_returns_none_when_fields_do_not_exist(mock_annu


@freeze_time("2024-10-25 12:00:00.000000")
def test_get_all_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
for status in mock_annual_limit_statuses:
def test_get_all_annual_limit_statuses(mock_annual_limit_client, mocked_service_id):
for status in STATUS_FIELDS:
mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4

statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)
assert len(statuses) == 4
assert all(value is not None for value in statuses.values())


def test_get_all_annual_limit_statuses_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None
statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)
assert set(statuses.keys()) == set(STATUS_FIELDS)
assert all(value is None for value in statuses.values())


@freeze_time("2024-10-25 12:00:00.000000")
def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
for status in mock_annual_limit_statuses:
def test_clear_annual_limit_statuses(mock_annual_limit_client, mocked_service_id):
for status in STATUS_FIELDS:
mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4

statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)
assert len(statuses) == 4
assert all(value == "2024-10-25" for value in statuses.values())

mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id)
assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None

statuses = mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)
assert set(statuses.keys()) == set(STATUS_FIELDS)
assert all(value is None for value in statuses.values())


@freeze_time("2024-10-25 12:00:00.000000")
Expand Down Expand Up @@ -251,50 +266,50 @@ def test_set_over_email_limit(mock_annual_limit_client, mocked_service_id):
assert result == datetime.utcnow().strftime("%Y-%m-%d")


def test_increment_sms_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
def test_increment_sms_delivered(mock_annual_limit_client, mocked_service_id):
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)

mock_annual_limit_client.increment_sms_delivered(mocked_service_id)

assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) == 2
for field in mock_notification_count_types:
for field in NOTIFICATION_FIELDS:
if field != SMS_DELIVERED:
assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1


def test_increment_sms_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
def test_increment_sms_failed(mock_annual_limit_client, mocked_service_id):
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)

mock_annual_limit_client.increment_sms_failed(mocked_service_id)

assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_FAILED) == 2
for field in mock_notification_count_types:
for field in NOTIFICATION_FIELDS:
if field != SMS_FAILED:
assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1


def test_increment_email_delivered(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
def test_increment_email_delivered(mock_annual_limit_client, mocked_service_id):
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)

mock_annual_limit_client.increment_email_delivered(mocked_service_id)

assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_DELIVERED) == 2
for field in mock_notification_count_types:
for field in NOTIFICATION_FIELDS:
if field != EMAIL_DELIVERED:
assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1


def test_increment_email_failed(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
def test_increment_email_failed(mock_annual_limit_client, mocked_service_id):
for field in NOTIFICATION_FIELDS:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)

mock_annual_limit_client.increment_email_failed(mocked_service_id)

assert mock_annual_limit_client.get_notification_count(mocked_service_id, EMAIL_FAILED) == 2
for field in mock_notification_count_types:
for field in NOTIFICATION_FIELDS:
if field != EMAIL_FAILED:
assert mock_annual_limit_client.get_notification_count(mocked_service_id, field) == 1

Expand Down
2 changes: 0 additions & 2 deletions tests/test_redis_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,7 @@ class TestRedisHashes:
"hash_key, fields_to_delete, expected_deleted, check_if_no_longer_exists",
[
("test:hash:key1", ["field1", "field2"], 2, False), # Delete specific fields in a hash
("test:hash:key1", None, 3, True), # Delete All fields in a specific hash
("test:hash:*", ["field1", "field2"], 6, False), # Delete specific fields in a group of hashes
("test:hash:*", None, 9, True), # Delete All fields in a group of hashes
],
)
def test_delete_hash_fields(
Expand Down

0 comments on commit 25ef1ae

Please sign in to comment.