diff --git a/.github/actions/waffles/requirements.txt b/.github/actions/waffles/requirements.txt index 90921392..1309ae3e 100644 --- a/.github/actions/waffles/requirements.txt +++ b/.github/actions/waffles/requirements.txt @@ -1,4 +1,4 @@ docopt==0.6.2 Flask==2.3.3 markupsafe==2.1.5 -git+https://github.com/cds-snc/notifier-utils.git@52.3.8#egg=notifications-utils +git+https://github.com/cds-snc/notifier-utils.git@52.3.9#egg=notifications-utils diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 823535a3..d508d392 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -57,14 +57,16 @@ def annual_limit_status_key(service_id): return f"annual-limit:{service_id}:status" -def decode_byte_dict(dict: dict, value_type=str): +def decode_byte_dict(byte_dict: dict, value_type=str): """ Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings. """ # Check if expected_value_type is one of the allowed types if value_type not in {int, float, str}: raise ValueError("expected_value_type must be int, float, or str") - return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items() if dict.items()} + if byte_dict is None or not byte_dict.items(): + return None + return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()} class RedisAnnualLimit: @@ -75,13 +77,21 @@ def init_app(self, app, *args, **kwargs): pass def increment_notification_count(self, service_id: str, field: str): + """Increments the specified daily notification count field for a service. + Fields that can be set: `sms_delivered`, `email_delivered`, `sms_failed`, `email_failed` + + Args: + service_id (str): _description_ + field (str): _description_ + """ self._redis_client.increment_hash_value(annual_limit_notifications_key(service_id), field) def get_notification_count(self, service_id: str, field: str): """ Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.) """ - return int(self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field)) + count = self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field) + return count and int(count.decode("utf-8")) def get_all_notification_counts(self, service_id: str): """ @@ -90,9 +100,11 @@ def get_all_notification_counts(self, service_id: str): return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int) def reset_all_notification_counts(self, service_ids=None): - """ - Resets all daily notification metrics. - :param: service_ids: list of service_ids to reset, if None, resets all services + """Resets all daily notification metrics. + + Args: + service_ids (Optional): A list of service_ids to reset notification counts for. Resets all services if None. + """ hashes = ( annual_limit_notifications_key("*") @@ -103,6 +115,22 @@ def reset_all_notification_counts(self, service_ids=None): self._redis_client.delete_hash_fields(hashes=hashes) def seed_annual_limit_notifications(self, service_id: str, mapping: dict): + """Seeds annual limit notifications for a service. + + Args: + service_id (str): Service to seed annual limit notifications for. + mapping (dict): A dict used to map notification counts to their respective fields formatted as follows + + Examples: + `mapping` format: + + { + "sms_delivered": int, + "email_delivered": int, + "sms_failed": int, + "email_failed": int + } + """ self._redis_client.bulk_set_hash_fields(key=annual_limit_notifications_key(service_id), mapping=mapping) def was_seeded_today(self, service_id): @@ -124,18 +152,41 @@ def clear_notification_counts(self, service_id: str): def set_annual_limit_status(self, service_id: str, field: str, value: datetime): """ - Sets the status (e.g., 'nearing_limit', 'over_limit') in the annual limits Redis hash. + Sets the specified status field in the annual limits Redis hash for a service. + Fields that can be set: `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at` + + Args: + service_id (str): The service to set the annual limit status field for + field (str): The field to set in the annual limit status hash. + value (datetime): The date to set the status to """ self._redis_client.set_hash_value(annual_limit_status_key(service_id), field, value.strftime("%Y-%m-%d")) def get_annual_limit_status(self, service_id: str, field: str): """ Retrieves the value of a specific annual limit status from the Redis hash. + Fields that can be fetched: `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at` + + Args: + service_id (str): The service to fetch the annual limit status field for + field (str): The field to fetch from the annual limit status hash values: + `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at` + + Returns: + str | None: The date the status was set, or None if the status has not been set """ response = self._redis_client.get_hash_field(annual_limit_status_key(service_id), field) - return response.decode("utf-8") if response is not None else None + return response and response.decode("utf-8") def get_all_annual_limit_statuses(self, service_id: str): + """Retrieves all annual limit status fields for a specified service from Redis + + Args: + service_id (str): The service to fetch annual limit statuses for + + Returns: + dict | None: A dictionary of annual limit statuses or None if no statuses are found + """ return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id))) def clear_annual_limit_statuses(self, service_id: str): diff --git a/notifications_utils/clients/redis/redis_client.py b/notifications_utils/clients/redis/redis_client.py index 3f4831ab..c96ff22a 100644 --- a/notifications_utils/clients/redis/redis_client.py +++ b/notifications_utils/clients/redis/redis_client.py @@ -125,7 +125,7 @@ def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception= """ Bulk set hash fields. :param pattern: the pattern to match keys - :param mappting: the mapping of fields to set + :param mapping: the mapping of fields to set :param raise_exception: True if we should allow the exception to bubble up """ if self.active: diff --git a/pyproject.toml b/pyproject.toml index fb1addce..2c41ce17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "notifications-utils" -version = "52.3.8" +version = "52.3.9" description = "Shared python code for Notification - Provides logging utils etc." authors = ["Canadian Digital Service"] license = "MIT license" diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py index 04adc1b9..c32557c7 100644 --- a/tests/test_annual_limit.py +++ b/tests/test_annual_limit.py @@ -113,18 +113,26 @@ def test_get_notification_count(mock_annual_limit_client, mocked_service_id): assert result == 1 +def test_get_notification_count_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id): + assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) is None + + def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): for field in mock_notification_count_types: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 +def test_get_all_notification_counts_returns_none_if_fields_do_not_exist(mock_annual_limit_client, mocked_service_id): + assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None + + def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id): for field in mock_notification_count_types: mock_annual_limit_client.increment_notification_count(mocked_service_id, field) assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4 mock_annual_limit_client.clear_notification_counts(mocked_service_id) - assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0 + assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None @pytest.mark.parametrize( @@ -147,7 +155,7 @@ def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notificat mock_annual_limit_client.reset_all_notification_counts() for service_id in service_ids: - assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 0 + assert mock_annual_limit_client.get_all_notification_counts(service_id) is None def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id): @@ -164,13 +172,28 @@ def test_get_annual_limit_status(mock_annual_limit_client, mocked_service_id): assert result == near_limit_date.strftime("%Y-%m-%d") +def test_get_annual_limit_status_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id): + assert mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) is None + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_get_all_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id): + for status in mock_annual_limit_statuses: + mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow()) + assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4 + + +def test_get_all_annual_limit_statuses_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id): + assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None + + @freeze_time("2024-10-25 12:00:00.000000") def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id): for status in mock_annual_limit_statuses: mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow()) assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4 mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id) - assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0 + assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None @freeze_time("2024-10-25 12:00:00.000000") @@ -196,6 +219,10 @@ def test_get_seeded_at(mock_annual_limit_client, seeded_at_value, expected_value assert result == expected_value +def test_get_seeded_at_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id): + assert mock_annual_limit_client.get_seeded_at(mocked_service_id) is None + + @freeze_time("2024-10-25 12:00:00.000000") def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id): mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id)