Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify all gets to return None if no values present in Redis #336

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/waffles/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
docopt==0.6.2
Flask==2.3.3
markupsafe==2.1.5
git+https://github.com/cds-snc/[email protected].8#egg=notifications-utils
git+https://github.com/cds-snc/[email protected].9#egg=notifications-utils
67 changes: 59 additions & 8 deletions notifications_utils/clients/redis/annual_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ def annual_limit_status_key(service_id):
return f"annual-limit:{service_id}:status"


def decode_byte_dict(dict: dict, value_type=str):
def decode_byte_dict(byte_dict: dict, value_type=str):
"""
Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings.
"""
# Check if expected_value_type is one of the allowed types
if value_type not in {int, float, str}:
raise ValueError("expected_value_type must be int, float, or str")
return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items() if dict.items()}
if byte_dict is None or not byte_dict.items():
return None
return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()}


class RedisAnnualLimit:
Expand All @@ -75,13 +77,21 @@ def init_app(self, app, *args, **kwargs):
pass

def increment_notification_count(self, service_id: str, field: str):
"""Increments the specified daily notification count field for a service.
Fields that can be set: `sms_delivered`, `email_delivered`, `sms_failed`, `email_failed`

Args:
service_id (str): _description_
field (str): _description_
"""
self._redis_client.increment_hash_value(annual_limit_notifications_key(service_id), field)

def get_notification_count(self, service_id: str, field: str):
"""
Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.)
"""
return int(self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field))
count = self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field)
return count and int(count.decode("utf-8"))

def get_all_notification_counts(self, service_id: str):
"""
Expand All @@ -90,9 +100,11 @@ def get_all_notification_counts(self, service_id: str):
return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int)

def reset_all_notification_counts(self, service_ids=None):
"""
Resets all daily notification metrics.
:param: service_ids: list of service_ids to reset, if None, resets all services
"""Resets all daily notification metrics.

Args:
service_ids (Optional): A list of service_ids to reset notification counts for. Resets all services if None.

"""
hashes = (
annual_limit_notifications_key("*")
Expand All @@ -103,6 +115,22 @@ def reset_all_notification_counts(self, service_ids=None):
self._redis_client.delete_hash_fields(hashes=hashes)

def seed_annual_limit_notifications(self, service_id: str, mapping: dict):
"""Seeds annual limit notifications for a service.

Args:
service_id (str): Service to seed annual limit notifications for.
mapping (dict): A dict used to map notification counts to their respective fields formatted as follows

Examples:
`mapping` format:

{
"sms_delivered": int,
"email_delivered": int,
"sms_failed": int,
"email_failed": int
}
"""
self._redis_client.bulk_set_hash_fields(key=annual_limit_notifications_key(service_id), mapping=mapping)

def was_seeded_today(self, service_id):
Expand All @@ -124,18 +152,41 @@ def clear_notification_counts(self, service_id: str):

def set_annual_limit_status(self, service_id: str, field: str, value: datetime):
"""
Sets the status (e.g., 'nearing_limit', 'over_limit') in the annual limits Redis hash.
Sets the specified status field in the annual limits Redis hash for a service.
Fields that can be set: `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at`

Args:
service_id (str): The service to set the annual limit status field for
field (str): The field to set in the annual limit status hash.
value (datetime): The date to set the status to
"""
self._redis_client.set_hash_value(annual_limit_status_key(service_id), field, value.strftime("%Y-%m-%d"))

def get_annual_limit_status(self, service_id: str, field: str):
"""
Retrieves the value of a specific annual limit status from the Redis hash.
Fields that can be fetched: `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at`

Args:
service_id (str): The service to fetch the annual limit status field for
field (str): The field to fetch from the annual limit status hash values:
`near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at`

Returns:
str | None: The date the status was set, or None if the status has not been set
"""
response = self._redis_client.get_hash_field(annual_limit_status_key(service_id), field)
return response.decode("utf-8") if response is not None else None
return response and response.decode("utf-8")

def get_all_annual_limit_statuses(self, service_id: str):
"""Retrieves all annual limit status fields for a specified service from Redis

Args:
service_id (str): The service to fetch annual limit statuses for

Returns:
dict | None: A dictionary of annual limit statuses or None if no statuses are found
"""
return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id)))

def clear_annual_limit_statuses(self, service_id: str):
Expand Down
2 changes: 1 addition & 1 deletion notifications_utils/clients/redis/redis_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception=
"""
Bulk set hash fields.
:param pattern: the pattern to match keys
:param mappting: the mapping of fields to set
:param mapptnig: the mapping of fields to set
whabanks marked this conversation as resolved.
Show resolved Hide resolved
:param raise_exception: True if we should allow the exception to bubble up
"""
if self.active:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "notifications-utils"
version = "52.3.8"
version = "52.3.9"
description = "Shared python code for Notification - Provides logging utils etc."
authors = ["Canadian Digital Service"]
license = "MIT license"
Expand Down
33 changes: 30 additions & 3 deletions tests/test_annual_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,26 @@ def test_get_notification_count(mock_annual_limit_client, mocked_service_id):
assert result == 1


def test_get_notification_count_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) is None


def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4


def test_get_all_notification_counts_returns_none_if_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None


def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4
mock_annual_limit_client.clear_notification_counts(mocked_service_id)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0
assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None


@pytest.mark.parametrize(
Expand All @@ -147,7 +155,7 @@ def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notificat
mock_annual_limit_client.reset_all_notification_counts()

for service_id in service_ids:
assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 0
assert mock_annual_limit_client.get_all_notification_counts(service_id) is None


def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id):
Expand All @@ -164,13 +172,28 @@ def test_get_annual_limit_status(mock_annual_limit_client, mocked_service_id):
assert result == near_limit_date.strftime("%Y-%m-%d")


def test_get_annual_limit_status_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) is None


@freeze_time("2024-10-25 12:00:00.000000")
def test_get_all_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
for status in mock_annual_limit_statuses:
mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4


def test_get_all_annual_limit_statuses_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None


@freeze_time("2024-10-25 12:00:00.000000")
def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
for status in mock_annual_limit_statuses:
mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4
mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id)
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0
assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None


@freeze_time("2024-10-25 12:00:00.000000")
Expand All @@ -196,6 +219,10 @@ def test_get_seeded_at(mock_annual_limit_client, seeded_at_value, expected_value
assert result == expected_value


def test_get_seeded_at_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_seeded_at(mocked_service_id) is None


@freeze_time("2024-10-25 12:00:00.000000")
def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id):
mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id)
Expand Down
Loading