diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..81463cefe 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + changed + - validate_country to decorator for validation on applicable functions diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 10ddca6e5..b9db387d6 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,5 +1,6 @@ import importlib from flask import Response +from functools import wraps import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem from typing import Union, Optional @@ -497,19 +498,26 @@ def get_requested_computations(household: dict): } -def validate_country(country_id: str) -> Union[None, Response]: +def validate_country(func): """Validate that a country ID is valid. If not, return a 404 response. Args: country_id (str): The country ID to validate. Returns: - + Response(404) if country is not valid, else continues """ - if country_id not in COUNTRIES: - body = dict( - status="error", - message=f"Country {country_id} not found. Available countries are: {', '.join(COUNTRIES.keys())}", - ) - return Response(json.dumps(body), status=404) - return None + + @wraps(func) + def validate_country_wrapper( + country_id: str, *args, **kwargs + ) -> Union[None, Response]: + if country_id not in COUNTRIES: + body = dict( + status="error", + message=f"Country {country_id} not found. Available countries are: {', '.join(COUNTRIES.keys())}", + ) + return Response(json.dumps(body), status=404) + return func(country_id, *args, **kwargs) + + return validate_country_wrapper diff --git a/policyengine_api/endpoints/economy/economy.py b/policyengine_api/endpoints/economy/economy.py index ed62eac2b..f1531beaf 100644 --- a/policyengine_api/endpoints/economy/economy.py +++ b/policyengine_api/endpoints/economy/economy.py @@ -36,6 +36,7 @@ def get_average_time(): return total_time / len(recent_jobs) +@validate_country def get_economic_impact( country_id: str, policy_id: str, baseline_policy_id: str = None ): @@ -50,9 +51,6 @@ def get_economic_impact( dict: The economic impact. """ print(f"Got request for {country_id} {policy_id} {baseline_policy_id}") - invalid_country = validate_country(country_id) - if invalid_country: - return invalid_country policy_id = int(policy_id or get_current_law_policy_id(country_id)) baseline_policy_id = int( diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index 72832a5c9..42574dafb 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -82,6 +82,7 @@ def get_household_year(household): return household_year +@validate_country def get_household(country_id: str, household_id: str) -> dict: """Get a household's input data with a given ID. @@ -89,12 +90,8 @@ def get_household(country_id: str, household_id: str) -> dict: country_id (str): The country ID. household_id (str): The household ID. """ - invalid_country = validate_country(country_id) - if invalid_country: - return invalid_country # Retrieve from the household table - row = database.query( f"SELECT * FROM household WHERE id = ? AND country_id = ?", (household_id, country_id), @@ -120,15 +117,13 @@ def get_household(country_id: str, household_id: str) -> dict: ) +@validate_country def post_household(country_id: str) -> dict: """Set a household's input data. Args: country_id (str): The country ID. """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found payload = request.json label = payload.get("label") @@ -169,6 +164,7 @@ def post_household(country_id: str) -> dict: ) +@validate_country def update_household(country_id: str, household_id: str) -> Response: """ Update a household via UPDATE request @@ -176,10 +172,6 @@ def update_household(country_id: str, household_id: str) -> Response: Args: country_id (str): The country ID """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - # Fetch existing household first try: row = database.query( @@ -258,6 +250,7 @@ def update_household(country_id: str, household_id: str) -> Response: ) +@validate_country def get_household_under_policy( country_id: str, household_id: str, policy_id: str ): @@ -268,9 +261,6 @@ def get_household_under_policy( household_id (str): The household ID. policy_id (str): The policy ID. """ - invalid_country = validate_country(country_id) - if invalid_country: - return invalid_country api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) @@ -393,6 +383,7 @@ def get_household_under_policy( ) +@validate_country def get_calculate(country_id: str, add_missing: bool = False) -> dict: """Lightweight endpoint for passing in household and policy JSON objects and calculating without storing data. @@ -400,10 +391,6 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict: country_id (str): The country ID. """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - payload = request.json household_json = payload.get("household", {}) policy_json = payload.get("policy", {}) diff --git a/policyengine_api/endpoints/metadata.py b/policyengine_api/endpoints/metadata.py index b1b9f1511..91491c8bb 100644 --- a/policyengine_api/endpoints/metadata.py +++ b/policyengine_api/endpoints/metadata.py @@ -1,14 +1,11 @@ from policyengine_api.country import COUNTRIES, validate_country +@validate_country def get_metadata(country_id: str) -> dict: """Get metadata for a country. Args: country_id (str): The country ID. """ - invalid_country = validate_country(country_id) - if invalid_country: - return invalid_country - return COUNTRIES.get(country_id).metadata diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index 4059068e0..f8e5fcacf 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -10,6 +10,7 @@ import sqlalchemy.exc +@validate_country def get_policy(country_id: str, policy_id: int) -> dict: """ Get policy data for a given country and policy ID. @@ -21,9 +22,7 @@ def get_policy(country_id: str, policy_id: int) -> dict: Returns: dict: The policy record. """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found + # Get the policy record for a given policy ID. row = database.query( f"SELECT * FROM policy WHERE country_id = ? AND id = ?", @@ -48,6 +47,7 @@ def get_policy(country_id: str, policy_id: int) -> dict: ) +@validate_country def set_policy( country_id: str, ) -> dict: @@ -59,9 +59,6 @@ def set_policy( Args: country_id (str): The country ID. """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found payload = request.json label = payload.pop("label", None) @@ -176,6 +173,7 @@ def set_policy( ) +@validate_country def get_policy_search(country_id: str) -> dict: """ Search for policies for a specified country @@ -195,6 +193,7 @@ def get_policy_search(country_id: str) -> dict: Example: GET /api/policies/us?query=tax&unique_only=true """ + query = request.args.get("query", "") # The "json.loads" default type is added to convert lowercase # "true" and "false" to Python-friendly bool values @@ -202,10 +201,6 @@ def get_policy_search(country_id: str) -> dict: "unique_only", default=False, type=json.loads ) - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - try: results = database.query( "SELECT id, label, policy_hash FROM policy WHERE country_id = ? AND label LIKE ?", @@ -269,6 +264,7 @@ def get_current_law_policy_id(country_id: str) -> int: }[country_id] +@validate_country def set_user_policy(country_id: str) -> dict: """ Adds a record (if unique, barring type) to the user_policy table @@ -277,10 +273,6 @@ def set_user_policy(country_id: str) -> dict: is currently unused """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - payload = request.json reform_label = payload.pop("reform_label", None) reform_id = payload.pop("reform_id") @@ -436,14 +428,12 @@ def set_user_policy(country_id: str) -> dict: ) +@validate_country def get_user_policy(country_id: str, user_id: str) -> dict: """ Fetch all saved user policies by user id """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found # Get the policy record for a given policy ID. rows = database.query( f"SELECT * FROM user_policies WHERE country_id = ? AND user_id = ?", @@ -488,15 +478,12 @@ def get_user_policy(country_id: str, user_id: str) -> dict: ) +@validate_country def update_user_policy(country_id: str) -> dict: """ Update any parts of a user_policy, given a user_policy ID """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - # Construct the relevant UPDATE request setter_array = [] args = [] diff --git a/policyengine_api/endpoints/tracer_analysis.py b/policyengine_api/endpoints/tracer_analysis.py index a42fd8e45..7aa0ba016 100644 --- a/policyengine_api/endpoints/tracer_analysis.py +++ b/policyengine_api/endpoints/tracer_analysis.py @@ -20,6 +20,7 @@ # TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit +@validate_country def execute_tracer_analysis( country_id: str, ): @@ -29,10 +30,6 @@ def execute_tracer_analysis( country_id (str): The country ID. """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - payload = request.json household_id = payload.get("household_id") diff --git a/policyengine_api/endpoints/user_profile.py b/policyengine_api/endpoints/user_profile.py index c3c877a1f..259e4727f 100644 --- a/policyengine_api/endpoints/user_profile.py +++ b/policyengine_api/endpoints/user_profile.py @@ -4,13 +4,11 @@ import json +@validate_country def set_user_profile(country_id: str) -> dict: """ Creates a new user_profile """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found payload = request.json primary_country = country_id @@ -86,6 +84,7 @@ def set_user_profile(country_id: str) -> dict: ) +@validate_country def get_user_profile(country_id: str) -> dict: """ Get a user profile in one of two ways: by auth0_id, @@ -93,10 +92,6 @@ def get_user_profile(country_id: str) -> dict: all data except auth0_id """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - if len(request.args) != 1: return Response( json.dumps( @@ -175,6 +170,7 @@ def get_user_profile(country_id: str) -> dict: ) +@validate_country def update_user_profile(country_id: str) -> dict: """ Update any part of a user_profile, given a user_id, @@ -182,10 +178,6 @@ def update_user_profile(country_id: str) -> dict: will assume malicious intent and 403 """ - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - # Construct the relevant UPDATE request setter_array = [] args = [] diff --git a/tests/python/test_country.py b/tests/python/test_country.py new file mode 100644 index 000000000..f2a279647 --- /dev/null +++ b/tests/python/test_country.py @@ -0,0 +1,27 @@ +from flask import Response +from policyengine_api.country import validate_country + + +@validate_country +def foo(country_id, other): + """ + A simple dummy test method for validation testing. Must be defined outside of the class (or within + the test functions themselves) due to complications with the `self` parameter for class methods. + """ + return "bar" + +class TestValidateCountry: + """ + Test that the @validate_country decorator returns 404 if the country does not exist, otherwise + continues execution of the function. + """ + + def test_valid_country(self): + result = foo("us", "extra_arg") + assert result == "bar" + + def test_invalid_country(self): + result = foo("baz", "extra_arg") + assert isinstance(result, Response) + assert result.status_code == 404 + \ No newline at end of file diff --git a/tests/python/test_policy.py b/tests/python/test_policy.py index 16dc5e570..802ef1c3e 100644 --- a/tests/python/test_policy.py +++ b/tests/python/test_policy.py @@ -44,6 +44,11 @@ def test_create_nonunique_policy(self, rest_client): (self.policy_hash, self.label, self.country_id), ) + def test_create_policy_invalid_country(self, rest_client): + res = rest_client.post("/au/policy", json=self.test_policy) + assert res.status_code == 404 + + class TestPolicySearch: country_id = "us"