From 6bd84200753624d4f2d7fdd80b3b2b41491e3c34 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Dec 2024 01:08:34 +0100 Subject: [PATCH] feat: Refactor household endpoints to match new structure --- policyengine_api/routes/household_routes.py | 8 +- .../services/household_service.py | 5 + .../payload_validators/validate_country.py | 1 + tests/python/test_household.py | 408 ++++++++++++++++++ 4 files changed, 419 insertions(+), 3 deletions(-) create mode 100644 tests/python/test_household.py diff --git a/policyengine_api/routes/household_routes.py b/policyengine_api/routes/household_routes.py index 5953d5fa..2d6f54f5 100644 --- a/policyengine_api/routes/household_routes.py +++ b/policyengine_api/routes/household_routes.py @@ -14,8 +14,8 @@ household_service = HouseholdService() -@validate_country @household_bp.route("/", methods=["GET"]) +@validate_country def get_household(country_id: str, household_id: str) -> Response: """ Get a household's input data with a given ID. @@ -24,6 +24,7 @@ def get_household(country_id: str, household_id: str) -> Response: country_id (str): The country ID. household_id (str): The household ID. """ + print(f"Got request for household {household_id} in country {country_id}") # Ensure that household ID is a number try: @@ -38,6 +39,7 @@ def get_household(country_id: str, household_id: str) -> Response: household: dict | None = household_service.get_household( country_id, household_id ) + print(household) if household is None: return Response( json.dumps( @@ -71,8 +73,8 @@ def get_household(country_id: str, household_id: str) -> Response: ) -@validate_country @household_bp.route("", methods=["POST"]) +@validate_country def post_household(country_id: str) -> Response: """ Set a household's input data. @@ -127,8 +129,8 @@ def post_household(country_id: str) -> Response: ) -@validate_country @household_bp.route("/", methods=["PUT"]) +@validate_country def update_household(country_id: str, household_id: str) -> Response: """ Update a household's input data. diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index fab7a973..223045ee 100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -24,11 +24,16 @@ def get_household(self, country_id: str, household_id: int) -> dict | None: (household_id, country_id), ).fetchone() + print(row) + print(type(row)) + # If row is present, we must JSON.loads the household_json household = None if row is not None: + print("Row is not None") household = dict(row) if household["household_json"]: + print("household_json is not None") household["household_json"] = json.loads( household["household_json"] ) diff --git a/policyengine_api/utils/payload_validators/validate_country.py b/policyengine_api/utils/payload_validators/validate_country.py index 3f1e0c5e..c891f98a 100644 --- a/policyengine_api/utils/payload_validators/validate_country.py +++ b/policyengine_api/utils/payload_validators/validate_country.py @@ -19,6 +19,7 @@ def validate_country(func): def validate_country_wrapper( country_id: str, *args, **kwargs ) -> Union[None, Response]: + print("Validating country") if country_id not in COUNTRIES: body = dict( status="error", diff --git a/tests/python/test_household.py b/tests/python/test_household.py new file mode 100644 index 00000000..6731864b --- /dev/null +++ b/tests/python/test_household.py @@ -0,0 +1,408 @@ +import pytest +from flask import Flask +import json +from unittest.mock import MagicMock, patch +from sqlalchemy.engine.row import LegacyRow + +from policyengine_api.routes.household_routes import household_bp +from policyengine_api.services.household_service import HouseholdService +from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS + +# TODO: Check if this format is correct +SAMPLE_HOUSEHOLD_DATA = { + "data": {"people": {"person1": {"age": 30, "income": 50000}}}, + "label": "Test Household", +} + +SAMPLE_DB_ROW = { + "id": 1, + "country_id": "us", + "household_json": json.dumps(SAMPLE_HOUSEHOLD_DATA["data"]), + "household_hash": "some-hash", + "label": "Test Household", + "api_version": "3.0.0", +} + + +# These will be moved to the correct location once +# testing PR that creates folder structure is merged +@pytest.fixture +def mock_database(): + """Mock the database module.""" + with patch( + "policyengine_api.services.household_service.database" + ) as mock_db: + yield mock_db + + +@pytest.fixture +def mock_hash_object(): + """Mock the hash_object function.""" + with patch( + "policyengine_api.services.household_service.hash_object" + ) as mock: + mock.return_value = "some-hash" + yield mock + + +class TestGetHousehold: + def test_get_existing_household(self, rest_client, mock_database): + """Test getting an existing household.""" + # Mock database response + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: SAMPLE_DB_ROW[x] + mock_row.keys.return_value = SAMPLE_DB_ROW.keys() + mock_database.query().fetchone.return_value = mock_row + + # Make request + response = rest_client.get("/us/household/1") + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + assert ( + data["result"]["household_json"] == SAMPLE_HOUSEHOLD_DATA["data"] + ) + + def test_get_nonexistent_household(self, rest_client, mock_database): + """Test getting a non-existent household.""" + mock_database.query().fetchone.return_value = None + + response = rest_client.get("/us/household/999") + data = json.loads(response.data) + + assert response.status_code == 404 + assert data["status"] == "error" + assert "not found" in data["message"] + + def test_get_household_invalid_id(self, rest_client): + """Test getting a household with invalid ID.""" + response = rest_client.get("/us/household/invalid") + + assert response.status_code == 400 + assert b"Invalid household ID" in response.data + + +class TestCreateHousehold: + def test_create_household_success( + self, rest_client, mock_database, mock_hash_object + ): + """Test successfully creating a new household.""" + # Mock database responses + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x] + mock_database.query().fetchone.return_value = mock_row + + response = rest_client.post( + "/us/household", + json=SAMPLE_HOUSEHOLD_DATA, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 201 + assert data["status"] == "ok" + assert data["result"]["household_id"] == 1 + + def test_create_household_invalid_payload(self, rest_client): + """Test creating a household with invalid payload.""" + invalid_payload = { + "label": "Test", + # Missing required 'data' field + } + + response = rest_client.post( + "/us/household", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Missing required keys" in response.data + + def test_create_household_invalid_label(self, rest_client): + """Test creating a household with invalid label type.""" + invalid_payload = { + "data": {}, + "label": 123, # Should be string or None + } + + response = rest_client.post( + "/us/household", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Label must be a string or None" in response.data + + +class TestUpdateHousehold: + def test_update_household_success( + self, rest_client, mock_database, mock_hash_object + ): + """Test successfully updating an existing household.""" + # Mock getting existing household + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: SAMPLE_DB_ROW[x] + mock_row.keys.return_value = SAMPLE_DB_ROW.keys() + mock_database.query().fetchone.return_value = mock_row + + updated_data = { + "data": {"people": {"person1": {"age": 31, "income": 55000}}}, + "label": "Updated Test Household", + } + + response = rest_client.put( + "/us/household/1", + json=updated_data, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["result"]["household_id"] == "1" + + def test_update_nonexistent_household(self, rest_client, mock_database): + """Test updating a non-existent household.""" + mock_database.query().fetchone.return_value = None + + response = rest_client.put( + "/us/household/999", + json=SAMPLE_HOUSEHOLD_DATA, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 404 + assert data["status"] == "error" + assert "not found" in data["message"] + + def test_update_household_invalid_payload(self, rest_client): + """Test updating a household with invalid payload.""" + invalid_payload = { + "label": "Test", + # Missing required 'data' field + } + + response = rest_client.put( + "/us/household/1", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Missing required keys" in response.data + + +# Service level tests +class TestHouseholdService: + def test_get_household(self, mock_database): + """Test HouseholdService.get_household method.""" + service = HouseholdService() + + # Mock database response + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: SAMPLE_DB_ROW[x] + mock_row.keys.return_value = SAMPLE_DB_ROW.keys() + mock_database.query().fetchone.return_value = mock_row + + result = service.get_household("us", 1) + + assert result is not None + assert result["household_json"] == SAMPLE_HOUSEHOLD_DATA["data"] + + def test_create_household(self, mock_database, mock_hash_object): + """Test HouseholdService.create_household method.""" + service = HouseholdService() + + # Mock database response for the ID query + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x] + mock_database.query().fetchone.return_value = mock_row + + household_id = service.create_household( + "us", SAMPLE_HOUSEHOLD_DATA["data"], SAMPLE_HOUSEHOLD_DATA["label"] + ) + + assert household_id == 1 + mock_database.query.assert_called() + + def test_update_household(self, mock_database, mock_hash_object): + """Test HouseholdService.update_household method.""" + service = HouseholdService() + + service.update_household( + "us", + "1", + SAMPLE_HOUSEHOLD_DATA["data"], + SAMPLE_HOUSEHOLD_DATA["label"], + ) + + mock_database.query.assert_called() + assert mock_hash_object.called + + +class TestHouseholdRouteValidation: + """Test validation and error handling in household routes.""" + + @pytest.mark.parametrize( + "invalid_payload", + [ + {}, # Empty payload + {"label": "Test"}, # Missing data field + {"data": None}, # None data + {"data": "not_a_dict"}, # Non-dict data + {"data": {}, "label": 123}, # Invalid label type + ], + ) + def test_post_household_invalid_payload( + self, rest_client, invalid_payload + ): + """Test POST endpoint with various invalid payloads.""" + response = rest_client.post( + "/us/household", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Unable to create new household" in response.data + + @pytest.mark.parametrize( + "invalid_id", + [ + "abc", # Non-numeric + "1.5", # Float + ], + ) + def test_get_household_invalid_id(self, rest_client, invalid_id): + """Test GET endpoint with invalid household IDs.""" + response = rest_client.get(f"/us/household/{invalid_id}") + print(response) + print(response.data) + + assert response.status_code == 400 + assert b"Invalid household ID" in response.data + + @pytest.mark.parametrize( + "country_id", + [ + "123", # Numeric + "us!!", # Special characters + "zz", # Non-ISO + "a" * 100, # Too long + ], + ) + def test_invalid_country_id(self, rest_client, country_id): + """Test endpoints with invalid country IDs.""" + # Test GET + get_response = rest_client.get(f"/{country_id}/household/1") + assert get_response.status_code == 400 + + # Test POST + post_response = rest_client.post( + f"/{country_id}/household", + json={"data": {}}, + content_type="application/json", + ) + assert post_response.status_code == 400 + + # Test PUT + put_response = rest_client.put( + f"/{country_id}/household/1", + json={"data": {}}, + content_type="application/json", + ) + assert put_response.status_code == 400 + + +class TestHouseholdRouteServiceErrors: + """Test handling of service-level errors in routes.""" + + @patch( + "policyengine_api.services.household_service.HouseholdService.get_household" + ) + def test_get_household_service_error(self, mock_get, rest_client): + """Test GET endpoint when service raises an error.""" + mock_get.side_effect = Exception("Database connection failed") + + response = rest_client.get("/us/household/1") + data = json.loads(response.data) + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Database connection failed" in data["message"] + + @patch( + "policyengine_api.services.household_service.HouseholdService.create_household" + ) + def test_post_household_service_error(self, mock_create, rest_client): + """Test POST endpoint when service raises an error.""" + mock_create.side_effect = Exception("Failed to create household") + + response = rest_client.post( + "/us/household", + json={"data": {"valid": "payload"}}, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Failed to create household" in data["message"] + + @patch( + "policyengine_api.services.household_service.HouseholdService.update_household" + ) + def test_put_household_service_error(self, mock_update, rest_client): + """Test PUT endpoint when service raises an error.""" + mock_update.side_effect = Exception("Failed to update household") + + # First mock the get_household call that checks existence + with patch( + "policyengine_api.services.household_service.HouseholdService.get_household" + ) as mock_get: + mock_get.return_value = {"id": 1} # Simulate existing household + + response = rest_client.put( + "/us/household/1", + json={"data": {"valid": "payload"}}, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Failed to update household" in data["message"] + + def test_missing_json_body(self, rest_client): + """Test endpoints when JSON body is missing.""" + # Test POST without JSON + post_response = rest_client.post("/us/household") + # Actually intercepted by server, which responds with 415, + # before we can even return a 400 + assert post_response.status_code in [400, 415] + + # Test PUT without JSON + put_response = rest_client.put("/us/household/1") + assert put_response.status_code in [400, 415] + + def test_malformed_json_body(self, rest_client): + """Test endpoints with malformed JSON body.""" + # Test POST with malformed JSON + post_response = rest_client.post( + "/us/household", + data="invalid json{", + content_type="application/json", + ) + assert post_response.status_code == 400 + + # Test PUT with malformed JSON + put_response = rest_client.put( + "/us/household/1", + data="invalid json{", + content_type="application/json", + ) + assert put_response.status_code == 400