Skip to content

Commit

Permalink
chore: Lint and changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
anth-volk committed Nov 21, 2024
1 parent 26d254b commit 865250f
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 196 deletions.
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: minor
changes:
changed:
- Refactored AI endpoints to match new routes/services/jobs architecture
- Disabled default buffering on App Engine deployments for AI endpoints
14 changes: 9 additions & 5 deletions policyengine_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

# Endpoints
from policyengine_api.routes.economy_routes import economy_bp
from policyengine_api.routes.simulation_analysis_routes import simulation_analysis_bp
from policyengine_api.routes.simulation_analysis_routes import (
simulation_analysis_bp,
)
from policyengine_api.routes.tracer_analysis_routes import tracer_analysis_bp
from .endpoints import (
get_home,
Expand Down Expand Up @@ -95,9 +97,8 @@

# Routes for AI analysis of economy microsim runs
app.register_blueprint(
simulation_analysis_bp,
url_prefix="/<country_id>/simulation-analysis"
)
simulation_analysis_bp, url_prefix="/<country_id>/simulation-analysis"
)

app.route("/<country_id>/user-policy", methods=["POST"])(set_user_policy)

Expand All @@ -115,7 +116,10 @@

app.route("/simulations", methods=["GET"])(get_simulations)

app.register_blueprint(tracer_analysis_bp, url_prefix="/<country_id>/tracer-analysis")
app.register_blueprint(
tracer_analysis_bp, url_prefix="/<country_id>/tracer-analysis"
)


@app.route("/liveness-check", methods=["GET"])
def liveness_check():
Expand Down
19 changes: 13 additions & 6 deletions policyengine_api/routes/simulation_analysis_routes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from flask import Blueprint, request, Response, stream_with_context
from policyengine_api.helpers import validate_country
from policyengine_api.services.simulation_analysis_service import SimulationAnalysisService
from policyengine_api.services.simulation_analysis_service import (
SimulationAnalysisService,
)

simulation_analysis_bp = Blueprint("simulation_analysis", __name__)
simulation_analysis_service = SimulationAnalysisService()


@simulation_analysis_bp.route("", methods=["POST"])
def execute_simulation_analysis(country_id):
print("Got POST request for simulation analysis")
Expand All @@ -20,7 +23,9 @@ def execute_simulation_analysis(country_id):

is_payload_valid, message = validate_payload(payload)
if not is_payload_valid:
return Response(status=400, response=f"Invalid JSON data; details: {message}")
return Response(
status=400, response=f"Invalid JSON data; details: {message}"
)

currency: str = payload.get("currency")
selected_version: str = payload.get("selected_version")
Expand Down Expand Up @@ -58,16 +63,18 @@ def execute_simulation_analysis(country_id):

# Set header to prevent buffering on Google App Engine deployment
# (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering)
response.headers['X-Accel-Buffering'] = 'no'
response.headers["X-Accel-Buffering"] = "no"

return response
except Exception as e:
return {
"status": "error",
"message": "An error occurred while executing the simulation analysis. Details: " + str(e),
"message": "An error occurred while executing the simulation analysis. Details: "
+ str(e),
"result": None,
}, 500


def validate_payload(payload: dict):
# Check if all required keys are present; note
# that the audience key is optional
Expand Down Expand Up @@ -97,7 +104,7 @@ def validate_payload(payload: dict):
missing_keys = [key for key in required_keys if key not in payload]
if missing_keys:
return False, f"Missing required keys: {missing_keys}"

# Check if all keys are of the right type
for key, value in payload.items():
if key in str_keys and not isinstance(value, str):
Expand All @@ -107,4 +114,4 @@ def validate_payload(payload: dict):
elif key in list_keys and not isinstance(value, list):
return False, f"Key '{key}' must be a list"

return True, None
return True, None
26 changes: 16 additions & 10 deletions policyengine_api/routes/tracer_analysis_routes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from flask import Blueprint, request, Response, stream_with_context
from policyengine_api.helpers import validate_country
from policyengine_api.services.tracer_analysis_service import TracerAnalysisService
from policyengine_api.services.tracer_analysis_service import (
TracerAnalysisService,
)

tracer_analysis_bp = Blueprint("tracer_analysis", __name__)
tracer_analysis_service = TracerAnalysisService()


@tracer_analysis_bp.route("", methods=["POST"])
def execute_tracer_analysis(country_id):

Expand All @@ -17,7 +20,9 @@ def execute_tracer_analysis(country_id):

is_payload_valid, message = validate_payload(payload)
if not is_payload_valid:
return Response(status=400, response=f"Invalid JSON data; details: {message}")
return Response(
status=400, response=f"Invalid JSON data; details: {message}"
)

household_id = payload.get("household_id")
policy_id = payload.get("policy_id")
Expand All @@ -36,20 +41,21 @@ def execute_tracer_analysis(country_id):
),
status=200,
)

# Set header to prevent buffering on Google App Engine deployment
# (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering)
response.headers['X-Accel-Buffering'] = 'no'
response.headers["X-Accel-Buffering"] = "no"

return response
except Exception as e:
return {
"status": "error",
"message": "An error occurred while executing the tracer analysis. Details: "
+ str(e),
"result": None,
"status": "error",
"message": "An error occurred while executing the tracer analysis. Details: "
+ str(e),
"result": None,
}, 500


def validate_payload(payload: dict):
# Validate payload
if not payload:
Expand All @@ -60,4 +66,4 @@ def validate_payload(payload: dict):
if key not in payload:
return False, f"Missing required key: {key}"

return True, None
return True, None
13 changes: 9 additions & 4 deletions policyengine_api/services/ai_analysis_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from typing import Generator
from policyengine_api.data import local_database


class AIAnalysisService:
"""
Base class for various AI analysis-based services,
including SimulationAnalysisService, that connects with the analysis
local database table
"""

def get_existing_analysis(self, prompt: str) -> Generator[str, None, None] | None:
def get_existing_analysis(
self, prompt: str
) -> Generator[str, None, None] | None:
"""
Get existing analysis from the local database
"""
Expand Down Expand Up @@ -43,9 +46,11 @@ def generate():
return generate()

def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]:

# Configure a Claude client
claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
claude_client = anthropic.Anthropic(
api_key=os.getenv("ANTHROPIC_API_KEY")
)

def generate():
chunk_size = 5
Expand Down Expand Up @@ -84,4 +89,4 @@ def generate():
(prompt, response_text, "ok"),
)

return generate()
return generate()
1 change: 1 addition & 0 deletions policyengine_api/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class EconomyService:
to the /economy route, which does not have its own table; therefore, it connects
with other services to access their respective tables
"""

def get_economic_impact(
self,
country_id,
Expand Down
1 change: 1 addition & 0 deletions policyengine_api/services/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class JobService(metaclass=Singleton):
jobs. This is not connected to any routes or tables, but interfaces
with the Redis queue to enqueue jobs and track their status.
"""

def __init__(self):
self.recent_jobs = {}

Expand Down
1 change: 1 addition & 0 deletions policyengine_api/services/policy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class PolicyService:
this will be connected to the /policy route and is partially connected
to the policy database table
"""

def get_policy_json(self, country_id, policy_id):
try:
policy_json = database.query(
Expand Down
146 changes: 75 additions & 71 deletions policyengine_api/services/simulation_analysis_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,74 @@

from policyengine_api.services.ai_analysis_service import AIAnalysisService


class SimulationAnalysisService(AIAnalysisService):
"""
Service for generating AI analysis of economy-wide simulation
runs; this is connected with the simulation_analysis route and
analysis database table
"""
def __init__(self):
super().__init__()

def execute_analysis(
self,
country_id: str,
currency: str,
selected_version: str,
time_period: str,
impact: dict,
policy_label: str,
policy: dict,
region: str,
relevant_parameters: list,
relevant_parameter_baseline_values: list,
audience: str | None,
):

# Check if the region is enhanced_cps
is_enhanced_cps = "enhanced_cps" in region

print("Generating prompt for economy-wide simulation analysis")

# Create prompt based on data
prompt = self._generate_simulation_analysis_prompt(
"""
Service for generating AI analysis of economy-wide simulation
runs; this is connected with the simulation_analysis route and
analysis database table
"""

def __init__(self):
super().__init__()

def execute_analysis(
self,
country_id: str,
currency: str,
selected_version: str,
time_period: str,
impact: dict,
policy_label: str,
policy: dict,
region: str,
relevant_parameters: list,
relevant_parameter_baseline_values: list,
audience: str | None,
):

# Check if the region is enhanced_cps
is_enhanced_cps = "enhanced_cps" in region

print("Generating prompt for economy-wide simulation analysis")

# Create prompt based on data
prompt = self._generate_simulation_analysis_prompt(
time_period,
region,
currency,
policy,
impact,
relevant_parameters,
relevant_parameter_baseline_values,
is_enhanced_cps,
selected_version,
country_id,
policy_label,
)

# Add audience description to end
prompt += self.audience_descriptions[audience]

print("Checking if AI analysis already exists for this prompt")
# If a calculated record exists for this prompt, return it as a
# streaming response
existing_analysis = self.get_existing_analysis(prompt)
if existing_analysis is not None:
return existing_analysis

print(
"Found no existing AI analysis; triggering new analysis with Claude"
)
# Otherwise, pass prompt to Claude, then return streaming function
try:
analysis = self.trigger_ai_analysis(prompt)
return analysis
except Exception as e:
raise e

def _generate_simulation_analysis_prompt(
self,
time_period,
region,
currency,
Expand All @@ -44,41 +81,8 @@ def execute_analysis(
selected_version,
country_id,
policy_label,
)

# Add audience description to end
prompt += self.audience_descriptions[audience]

print("Checking if AI analysis already exists for this prompt")
# If a calculated record exists for this prompt, return it as a
# streaming response
existing_analysis = self.get_existing_analysis(prompt)
if existing_analysis is not None:
return existing_analysis

print("Found no existing AI analysis; triggering new analysis with Claude")
# Otherwise, pass prompt to Claude, then return streaming function
try:
analysis = self.trigger_ai_analysis(prompt)
return analysis
except Exception as e:
raise e

def _generate_simulation_analysis_prompt(
self,
time_period,
region,
currency,
policy,
impact,
relevant_parameters,
relevant_parameter_baseline_values,
is_enhanced_cps,
selected_version,
country_id,
policy_label,
):
return f"""
):
return f"""
I'm using PolicyEngine, a free, open source tool to compute the impact of
public policy. I'm writing up an economic analysis of a hypothetical tax-benefit
policy reform. Please write the analysis for me using the details below, in
Expand Down Expand Up @@ -184,9 +188,9 @@ def _generate_simulation_analysis_prompt(
impact["inequality"],
)}
"""
audience_descriptions = {
"ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.",
"Normal": "Write this for a policy analyst who knows a bit about economics and policy.",
"Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.",
}

audience_descriptions = {
"ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.",
"Normal": "Write this for a policy analyst who knows a bit about economics and policy.",
"Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.",
}
Loading

0 comments on commit 865250f

Please sign in to comment.