Skip to content

Commit

Permalink
Switch @validate_country to decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
ldorner authored and ldorner committed Nov 12, 2024
1 parent 1e42f63 commit f4ea3f7
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 70 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
changed
- validate_country to decorator for validation on applicable functions
26 changes: 17 additions & 9 deletions policyengine_api/country.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
from flask import Response
from functools import wraps
import json
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
from typing import Union, Optional
Expand Down Expand Up @@ -497,19 +498,26 @@ def get_requested_computations(household: dict):
}


def validate_country(country_id: str) -> Union[None, Response]:
def validate_country(func):
"""Validate that a country ID is valid. If not, return a 404 response.
Args:
country_id (str): The country ID to validate.
Returns:
Response(404) if country is not valid, else continues
"""
if country_id not in COUNTRIES:
body = dict(
status="error",
message=f"Country {country_id} not found. Available countries are: {', '.join(COUNTRIES.keys())}",
)
return Response(json.dumps(body), status=404)
return None

@wraps(func)
def validate_country_wrapper(
country_id: str, *args, **kwargs
) -> Union[None, Response]:
if country_id not in COUNTRIES:
body = dict(
status="error",
message=f"Country {country_id} not found. Available countries are: {', '.join(COUNTRIES.keys())}",
)
return Response(json.dumps(body), status=404)
return func(country_id, *args, **kwargs)

return validate_country_wrapper
4 changes: 1 addition & 3 deletions policyengine_api/endpoints/economy/economy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_average_time():
return total_time / len(recent_jobs)


@validate_country
def get_economic_impact(
country_id: str, policy_id: str, baseline_policy_id: str = None
):
Expand All @@ -50,9 +51,6 @@ def get_economic_impact(
dict: The economic impact.
"""
print(f"Got request for {country_id} {policy_id} {baseline_policy_id}")
invalid_country = validate_country(country_id)
if invalid_country:
return invalid_country

policy_id = int(policy_id or get_current_law_policy_id(country_id))
baseline_policy_id = int(
Expand Down
23 changes: 5 additions & 18 deletions policyengine_api/endpoints/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,16 @@ def get_household_year(household):
return household_year


@validate_country
def get_household(country_id: str, household_id: str) -> dict:
"""Get a household's input data with a given ID.
Args:
country_id (str): The country ID.
household_id (str): The household ID.
"""
invalid_country = validate_country(country_id)
if invalid_country:
return invalid_country

# Retrieve from the household table

row = database.query(
f"SELECT * FROM household WHERE id = ? AND country_id = ?",
(household_id, country_id),
Expand All @@ -120,15 +117,13 @@ def get_household(country_id: str, household_id: str) -> dict:
)


@validate_country
def post_household(country_id: str) -> dict:
"""Set a household's input data.
Args:
country_id (str): The country ID.
"""
country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json
label = payload.get("label")
Expand Down Expand Up @@ -169,17 +164,14 @@ def post_household(country_id: str) -> dict:
)


@validate_country
def update_household(country_id: str, household_id: str) -> Response:
"""
Update a household via UPDATE request
Args: country_id (str): The country ID
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

# Fetch existing household first
try:
row = database.query(
Expand Down Expand Up @@ -258,6 +250,7 @@ def update_household(country_id: str, household_id: str) -> Response:
)


@validate_country
def get_household_under_policy(
country_id: str, household_id: str, policy_id: str
):
Expand All @@ -268,9 +261,6 @@ def get_household_under_policy(
household_id (str): The household ID.
policy_id (str): The policy ID.
"""
invalid_country = validate_country(country_id)
if invalid_country:
return invalid_country

api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id)

Expand Down Expand Up @@ -393,17 +383,14 @@ def get_household_under_policy(
)


@validate_country
def get_calculate(country_id: str, add_missing: bool = False) -> dict:
"""Lightweight endpoint for passing in household and policy JSON objects and calculating without storing data.
Args:
country_id (str): The country ID.
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json
household_json = payload.get("household", {})
policy_json = payload.get("policy", {})
Expand Down
5 changes: 1 addition & 4 deletions policyengine_api/endpoints/metadata.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from policyengine_api.country import COUNTRIES, validate_country


@validate_country
def get_metadata(country_id: str) -> dict:
"""Get metadata for a country.
Args:
country_id (str): The country ID.
"""
invalid_country = validate_country(country_id)
if invalid_country:
return invalid_country

return COUNTRIES.get(country_id).metadata
29 changes: 8 additions & 21 deletions policyengine_api/endpoints/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sqlalchemy.exc


@validate_country
def get_policy(country_id: str, policy_id: int) -> dict:
"""
Get policy data for a given country and policy ID.
Expand All @@ -21,9 +22,7 @@ def get_policy(country_id: str, policy_id: int) -> dict:
Returns:
dict: The policy record.
"""
country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

# Get the policy record for a given policy ID.
row = database.query(
f"SELECT * FROM policy WHERE country_id = ? AND id = ?",
Expand All @@ -48,6 +47,7 @@ def get_policy(country_id: str, policy_id: int) -> dict:
)


@validate_country
def set_policy(
country_id: str,
) -> dict:
Expand All @@ -59,9 +59,6 @@ def set_policy(
Args:
country_id (str): The country ID.
"""
country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json
label = payload.pop("label", None)
Expand Down Expand Up @@ -176,6 +173,7 @@ def set_policy(
)


@validate_country
def get_policy_search(country_id: str) -> dict:
"""
Search for policies for a specified country
Expand All @@ -195,17 +193,14 @@ def get_policy_search(country_id: str) -> dict:
Example:
GET /api/policies/us?query=tax&unique_only=true
"""

query = request.args.get("query", "")
# The "json.loads" default type is added to convert lowercase
# "true" and "false" to Python-friendly bool values
unique_only = request.args.get(
"unique_only", default=False, type=json.loads
)

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

try:
results = database.query(
"SELECT id, label, policy_hash FROM policy WHERE country_id = ? AND label LIKE ?",
Expand Down Expand Up @@ -269,6 +264,7 @@ def get_current_law_policy_id(country_id: str) -> int:
}[country_id]


@validate_country
def set_user_policy(country_id: str) -> dict:
"""
Adds a record (if unique, barring type) to the user_policy table
Expand All @@ -277,10 +273,6 @@ def set_user_policy(country_id: str) -> dict:
is currently unused
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json
reform_label = payload.pop("reform_label", None)
reform_id = payload.pop("reform_id")
Expand Down Expand Up @@ -436,14 +428,12 @@ def set_user_policy(country_id: str) -> dict:
)


@validate_country
def get_user_policy(country_id: str, user_id: str) -> dict:
"""
Fetch all saved user policies by user id
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found
# Get the policy record for a given policy ID.
rows = database.query(
f"SELECT * FROM user_policies WHERE country_id = ? AND user_id = ?",
Expand Down Expand Up @@ -488,15 +478,12 @@ def get_user_policy(country_id: str, user_id: str) -> dict:
)


@validate_country
def update_user_policy(country_id: str) -> dict:
"""
Update any parts of a user_policy, given a user_policy ID
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

# Construct the relevant UPDATE request
setter_array = []
args = []
Expand Down
5 changes: 1 addition & 4 deletions policyengine_api/endpoints/tracer_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit


@validate_country
def execute_tracer_analysis(
country_id: str,
):
Expand All @@ -29,10 +30,6 @@ def execute_tracer_analysis(
country_id (str): The country ID.
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json

household_id = payload.get("household_id")
Expand Down
14 changes: 3 additions & 11 deletions policyengine_api/endpoints/user_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import json


@validate_country
def set_user_profile(country_id: str) -> dict:
"""
Creates a new user_profile
"""
country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json
primary_country = country_id
Expand Down Expand Up @@ -86,17 +84,14 @@ def set_user_profile(country_id: str) -> dict:
)


@validate_country
def get_user_profile(country_id: str) -> dict:
"""
Get a user profile in one of two ways: by auth0_id,
which returns all data, and by user_id, which returns
all data except auth0_id
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

if len(request.args) != 1:
return Response(
json.dumps(
Expand Down Expand Up @@ -175,17 +170,14 @@ def get_user_profile(country_id: str) -> dict:
)


@validate_country
def update_user_profile(country_id: str) -> dict:
"""
Update any part of a user_profile, given a user_id,
except the auth0_id value; any attempt to edit this
will assume malicious intent and 403
"""

country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

# Construct the relevant UPDATE request
setter_array = []
args = []
Expand Down
27 changes: 27 additions & 0 deletions tests/python/test_country.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from flask import Response
from policyengine_api.country import validate_country


@validate_country
def foo(country_id, other):
"""
A simple dummy test method for validation testing. Must be defined outside of the class (or within
the test functions themselves) due to complications with the `self` parameter for class methods.
"""
return "bar"

class TestValidateCountry:
"""
Test that the @validate_country decorator returns 404 if the country does not exist, otherwise
continues execution of the function.
"""

def test_valid_country(self):
result = foo("us", "extra_arg")
assert result == "bar"

def test_invalid_country(self):
result = foo("baz", "extra_arg")
assert isinstance(result, Response)
assert result.status_code == 404

Loading

0 comments on commit f4ea3f7

Please sign in to comment.