From acfde00d2c7bb9713afeb3e67e41cf8bea988e10 Mon Sep 17 00:00:00 2001 From: William B <7444334+whabanks@users.noreply.github.com> Date: Wed, 13 Nov 2024 17:19:50 -0400 Subject: [PATCH] Fix seeding bugs add missing tests (#335) --- .github/actions/waffles/requirements.txt | 2 +- .../clients/redis/annual_limit.py | 5 ++-- pyproject.toml | 2 +- tests/test_annual_limit.py | 23 +++++++++++++++++++ 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/.github/actions/waffles/requirements.txt b/.github/actions/waffles/requirements.txt index 80b4baf0..90921392 100644 --- a/.github/actions/waffles/requirements.txt +++ b/.github/actions/waffles/requirements.txt @@ -1,4 +1,4 @@ docopt==0.6.2 Flask==2.3.3 markupsafe==2.1.5 -git+https://github.com/cds-snc/notifier-utils.git@52.3.7#egg=notifications-utils +git+https://github.com/cds-snc/notifier-utils.git@52.3.8#egg=notifications-utils diff --git a/notifications_utils/clients/redis/annual_limit.py b/notifications_utils/clients/redis/annual_limit.py index 20bab04e..823535a3 100644 --- a/notifications_utils/clients/redis/annual_limit.py +++ b/notifications_utils/clients/redis/annual_limit.py @@ -64,7 +64,7 @@ def decode_byte_dict(dict: dict, value_type=str): # Check if expected_value_type is one of the allowed types if value_type not in {int, float, str}: raise ValueError("expected_value_type must be int, float, or str") - return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items()} + return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items() if dict.items()} class RedisAnnualLimit: @@ -110,7 +110,8 @@ def was_seeded_today(self, service_id): return last_seeded_time == datetime.utcnow().strftime("%Y-%m-%d") if last_seeded_time else False def get_seeded_at(self, service_id: str): - return self._redis_client.get_hash_field(annual_limit_status_key(service_id), SEEDED_AT).decode("utf-8") + seeded_at = self._redis_client.get_hash_field(annual_limit_status_key(service_id), SEEDED_AT) + return seeded_at and seeded_at.decode("utf-8") def set_seeded_at(self, service_id): self._redis_client.set_hash_value(annual_limit_status_key(service_id), SEEDED_AT, datetime.utcnow().strftime("%Y-%m-%d")) diff --git a/pyproject.toml b/pyproject.toml index 4aa32641..fb1addce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "notifications-utils" -version = "52.3.7" +version = "52.3.8" description = "Shared python code for Notification - Provides logging utils etc." authors = ["Canadian Digital Service"] license = "MIT license" diff --git a/tests/test_annual_limit.py b/tests/test_annual_limit.py index b11e6336..04adc1b9 100644 --- a/tests/test_annual_limit.py +++ b/tests/test_annual_limit.py @@ -173,6 +173,29 @@ def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0 +@freeze_time("2024-10-25 12:00:00.000000") +@pytest.mark.parametrize("seeded_at_value, expected_value", [(b"2024-10-25", True), (None, False)]) +def test_was_seeded_today(mock_annual_limit_client, seeded_at_value, expected_value, mocked_service_id, mocker): + mocker.patch.object(mock_annual_limit_client._redis_client, "get_hash_field", return_value=seeded_at_value) + result = mock_annual_limit_client.was_seeded_today(mocked_service_id) + assert result == expected_value + + +@freeze_time("2024-10-25 12:00:00.000000") +def test_set_seeded_at(mock_annual_limit_client, mocked_service_id): + mock_annual_limit_client.set_seeded_at(mocked_service_id) + result = mock_annual_limit_client.get_seeded_at(mocked_service_id) + assert result == datetime.utcnow().strftime("%Y-%m-%d") + + +@freeze_time("2024-10-25 12:00:00.000000") +@pytest.mark.parametrize("seeded_at_value, expected_value", [(b"2024-10-25", "2024-10-25"), (None, None)]) +def test_get_seeded_at(mock_annual_limit_client, seeded_at_value, expected_value, mocked_service_id, mocker): + mocker.patch.object(mock_annual_limit_client._redis_client, "get_hash_field", return_value=seeded_at_value) + result = mock_annual_limit_client.get_seeded_at(mocked_service_id) + assert result == expected_value + + @freeze_time("2024-10-25 12:00:00.000000") def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id): mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id)