diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 2f6b25d8..732e62e0 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -57,11 +57,14 @@ def annual_limit_status_key(service_id): return f"annual-limit:{service_id}:status" -def decode_byte_dict(dict: dict): +def decode_byte_dict(dict: dict, value_type=str): """ Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings. """ - return {key.decode("utf-8"): value.decode("utf-8") for key, value in dict.items()} + # Check if expected_value_type is one of the allowed types + if value_type not in {int, float, str}: + raise ValueError("expected_value_type must be int, float, or str") + return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items()} class RedisAnnualLimit: @@ -84,7 +87,7 @@ def get_all_notification_counts(self, service_id: str): """ Retrieves all daily notification metrics for a service. """ - return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id))) + return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int) def reset_all_notification_counts(self, service_ids=None): """