diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index a1e922b4..450aea00 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -12,7 +12,7 @@ @validate_country @economy_bp.route("//over/", methods=["GET"]) -def get_economic_impact(country_id, policy_id, baseline_policy_id): +def get_economic_impact(country_id: str, policy_id: str | int, baseline_policy_id: str | int): policy_id = int(policy_id or get_current_law_policy_id(country_id)) baseline_policy_id = int( diff --git a/policyengine_api/routes/policy_routes.py b/policyengine_api/routes/policy_routes.py index 762c63e7..2b645f57 100644 --- a/policyengine_api/routes/policy_routes.py +++ b/policyengine_api/routes/policy_routes.py @@ -39,9 +39,6 @@ def get_policy(country_id: str, policy_id: int | str) -> Response: policy_id = int(policy_id) policy: dict | None = policy_service.get_policy(country_id, policy_id) - print(policy) - print(type(policy)) - print(type(policy["policy_json"])) if policy is None: return Response( diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 66beaceb..3a59a766 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -21,15 +21,18 @@ class EconomyService: def get_economic_impact( self, - country_id, - policy_id, - baseline_policy_id, - region, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, dataset, - time_period, - options, - api_version, + time_period: str, + options: dict, + api_version: str, ): + """ + Calculate the society-wide economic impact of a policy reform. + """ try: # Note for anyone modifying options_hash: redis-queue treats ":" as a namespace # delimiter; don't use colons in options_hash @@ -172,15 +175,19 @@ def get_economic_impact( def _get_previous_impacts( self, - country_id, - policy_id, - baseline_policy_id, - region, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, dataset, - time_period, - options_hash, - api_version, + time_period: str, + options_hash: str, + api_version: str, ): + """ + Fetch any previous simulation runs for the given policy reform. + """ + previous_impacts = reform_impacts_service.get_all_reform_impacts( country_id, policy_id, @@ -203,16 +210,19 @@ def _get_previous_impacts( def _set_impact_computing( self, - country_id, - policy_id, - baseline_policy_id, - region, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, dataset, - time_period, - options, - options_hash, - api_version, + time_period: str, + options: dict, + options_hash: str, + api_version: str, ): + """ + In the reform_impact table, set the status of the impact to "computing". + """ try: reform_impacts_service.set_reform_impact( country_id, diff --git a/policyengine_api/services/policy_service.py b/policyengine_api/services/policy_service.py index 81009671..80cf98b6 100644 --- a/policyengine_api/services/policy_service.py +++ b/policyengine_api/services/policy_service.py @@ -1,4 +1,5 @@ import json +from sqlalchemy.engine.row import LegacyRow from policyengine_api.data import database from policyengine_api.utils import hash_object @@ -23,15 +24,18 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None: Returns dict | None -- the policy data, or None if not found """ + print("Getting policy") + try: # If no policy found, this will return None - policy = database.query( + row: LegacyRow = database.query( "SELECT * FROM policy WHERE country_id = ? AND id = ?", (country_id, policy_id), ).fetchone() - # Note that policy_json field stored as string in database; - # must be converted before handing back + # policy_json is JSON and must be loaded, if present; to enable, + # we must convert the row to a dictionary + policy = dict(row) if policy and policy["policy_json"]: policy["policy_json"] = json.loads(policy["policy_json"]) return policy @@ -39,7 +43,11 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None: print(f"Error getting policy: {str(e)}") raise e - def get_policy_json(self, country_id, policy_id): + def get_policy_json(self, country_id: str, policy_id: int): + """ + Fetch policy JSON based only on policy ID and country ID + """ + print("Getting policy json") try: policy_json = database.query( f"SELECT policy_json FROM policy WHERE country_id = ? AND id = ?", @@ -65,18 +73,21 @@ def set_policy( tuple[int, str, bool] -- the new policy ID, a message, and whether or not the policy already existed """ + print("Setting new policy") try: policy_hash = hash_object(policy_json) api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) # Check if policy already exists + print("Checking if policy exists") existing_policy = self._get_unique_policy_with_label( country_id, policy_hash, label ) # If so, pass appropriate values back if existing_policy: + print("Policy already exists") existing_policy_id = str(existing_policy["id"]) message = ( "Warning: Record created previously with this label. To create " @@ -86,6 +97,7 @@ def set_policy( return existing_policy["id"], "Policy already exists", True # Otherwise, insert the new policy... + print("Policy does not exist; creating new policy") self._create_new_policy( country_id, policy_json, policy_hash, label, api_version ) diff --git a/tests/python/test_policy_service.py b/tests/python/test_policy_service.py index 1d69f8ae..ee7d86f3 100644 --- a/tests/python/test_policy_service.py +++ b/tests/python/test_policy_service.py @@ -42,6 +42,9 @@ def test_get_policy_success( # Test result = policy_service.get_policy("us", self.test_policy_id) + print(result) + print(result["policy_json"]) + print(type(result["policy_json"])) # Verify assert result is not None assert isinstance(result["policy_json"], dict) diff --git a/tests/python/test_yearly_var_removal.py b/tests/python/test_yearly_var_removal.py index 34fa6e42..34db4b49 100644 --- a/tests/python/test_yearly_var_removal.py +++ b/tests/python/test_yearly_var_removal.py @@ -253,14 +253,19 @@ def test_get_calculate(client): ) as f: test_household = json.load(f) - test_policy = policy_service.get_policy("us", CURRENT_LAW_US)[ - "policy_json" - ] + # Current law is represented by empty dict/empty JSON + test_policy = {} + print(test_policy) + print(type(test_policy)) test_object["policy"] = test_policy test_object["household"] = test_household res = client.post("/us/calculate-full", json=test_object) + print(res) + print(res.text) + print(json.loads(res.text)) + print(json.loads(res.text)["result"]) result_object = json.loads(res.text)["result"] # Create a dict of entity singular and plural terms for testing