Skip to content

Commit

Permalink
fix: Add type hinting and proper type handling to various services
Browse files Browse the repository at this point in the history
  • Loading branch information
anth-volk committed Nov 27, 2024
1 parent 80be8f8 commit b46cc3a
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 33 deletions.
2 changes: 1 addition & 1 deletion policyengine_api/routes/economy_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@validate_country
@economy_bp.route("/<policy_id>/over/<baseline_policy_id>", methods=["GET"])
def get_economic_impact(country_id, policy_id, baseline_policy_id):
def get_economic_impact(country_id: str, policy_id: str | int, baseline_policy_id: str | int):

policy_id = int(policy_id or get_current_law_policy_id(country_id))
baseline_policy_id = int(
Expand Down
3 changes: 0 additions & 3 deletions policyengine_api/routes/policy_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def get_policy(country_id: str, policy_id: int | str) -> Response:
policy_id = int(policy_id)

policy: dict | None = policy_service.get_policy(country_id, policy_id)
print(policy)
print(type(policy))
print(type(policy["policy_json"]))

if policy is None:
return Response(
Expand Down
54 changes: 32 additions & 22 deletions policyengine_api/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@ class EconomyService:

def get_economic_impact(
self,
country_id,
policy_id,
baseline_policy_id,
region,
country_id: str,
policy_id: int,
baseline_policy_id: int,
region: str,
dataset,
time_period,
options,
api_version,
time_period: str,
options: dict,
api_version: str,
):
"""
Calculate the society-wide economic impact of a policy reform.
"""
try:
# Note for anyone modifying options_hash: redis-queue treats ":" as a namespace
# delimiter; don't use colons in options_hash
Expand Down Expand Up @@ -172,15 +175,19 @@ def get_economic_impact(

def _get_previous_impacts(
self,
country_id,
policy_id,
baseline_policy_id,
region,
country_id: str,
policy_id: int,
baseline_policy_id: int,
region: str,
dataset,
time_period,
options_hash,
api_version,
time_period: str,
options_hash: str,
api_version: str,
):
"""
Fetch any previous simulation runs for the given policy reform.
"""

previous_impacts = reform_impacts_service.get_all_reform_impacts(
country_id,
policy_id,
Expand All @@ -203,16 +210,19 @@ def _get_previous_impacts(

def _set_impact_computing(
self,
country_id,
policy_id,
baseline_policy_id,
region,
country_id: str,
policy_id: int,
baseline_policy_id: int,
region: str,
dataset,
time_period,
options,
options_hash,
api_version,
time_period: str,
options: dict,
options_hash: str,
api_version: str,
):
"""
In the reform_impact table, set the status of the impact to "computing".
"""
try:
reform_impacts_service.set_reform_impact(
country_id,
Expand Down
20 changes: 16 additions & 4 deletions policyengine_api/services/policy_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow

from policyengine_api.data import database
from policyengine_api.utils import hash_object
Expand All @@ -23,23 +24,30 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None:
Returns
dict | None -- the policy data, or None if not found
"""
print("Getting policy")

try:
# If no policy found, this will return None
policy = database.query(
row: LegacyRow = database.query(
"SELECT * FROM policy WHERE country_id = ? AND id = ?",
(country_id, policy_id),
).fetchone()

# Note that policy_json field stored as string in database;
# must be converted before handing back
# policy_json is JSON and must be loaded, if present; to enable,
# we must convert the row to a dictionary
policy = dict(row)
if policy and policy["policy_json"]:
policy["policy_json"] = json.loads(policy["policy_json"])
return policy
except Exception as e:
print(f"Error getting policy: {str(e)}")
raise e

def get_policy_json(self, country_id, policy_id):
def get_policy_json(self, country_id: str, policy_id: int):
"""
Fetch policy JSON based only on policy ID and country ID
"""
print("Getting policy json")
try:
policy_json = database.query(
f"SELECT policy_json FROM policy WHERE country_id = ? AND id = ?",
Expand All @@ -65,18 +73,21 @@ def set_policy(
tuple[int, str, bool] -- the new policy ID, a message, and whether or not
the policy already existed
"""
print("Setting new policy")

try:

policy_hash = hash_object(policy_json)
api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id)
# Check if policy already exists
print("Checking if policy exists")
existing_policy = self._get_unique_policy_with_label(
country_id, policy_hash, label
)

# If so, pass appropriate values back
if existing_policy:
print("Policy already exists")
existing_policy_id = str(existing_policy["id"])
message = (
"Warning: Record created previously with this label. To create "
Expand All @@ -86,6 +97,7 @@ def set_policy(
return existing_policy["id"], "Policy already exists", True

# Otherwise, insert the new policy...
print("Policy does not exist; creating new policy")
self._create_new_policy(
country_id, policy_json, policy_hash, label, api_version
)
Expand Down
3 changes: 3 additions & 0 deletions tests/python/test_policy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_get_policy_success(
# Test
result = policy_service.get_policy("us", self.test_policy_id)

print(result)
print(result["policy_json"])
print(type(result["policy_json"]))
# Verify
assert result is not None
assert isinstance(result["policy_json"], dict)
Expand Down
11 changes: 8 additions & 3 deletions tests/python/test_yearly_var_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,19 @@ def test_get_calculate(client):
) as f:
test_household = json.load(f)

test_policy = policy_service.get_policy("us", CURRENT_LAW_US)[
"policy_json"
]
# Current law is represented by empty dict/empty JSON
test_policy = {}
print(test_policy)
print(type(test_policy))

test_object["policy"] = test_policy
test_object["household"] = test_household

res = client.post("/us/calculate-full", json=test_object)
print(res)
print(res.text)
print(json.loads(res.text))
print(json.loads(res.text)["result"])
result_object = json.loads(res.text)["result"]

# Create a dict of entity singular and plural terms for testing
Expand Down

0 comments on commit b46cc3a

Please sign in to comment.