From 6531c5b64bdbe0a12132d24b351d226964dd8fc8 Mon Sep 17 00:00:00 2001
From: Anthony Volk
Date: Thu, 21 Nov 2024 19:40:10 +0100
Subject: [PATCH 01/14] feat: Refactor routes to Blueprint

---
 policyengine_api/api.py                               | 11 +++++++----
 policyengine_api/routes/simulation_analysis_routes.py |  8 ++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)
 create mode 100644 policyengine_api/routes/simulation_analysis_routes.py

diff --git a/policyengine_api/api.py b/policyengine_api/api.py
index 70d409b8..e68eb2e8 100644
--- a/policyengine_api/api.py
+++ b/policyengine_api/api.py
@@ -14,6 +14,7 @@

 # Endpoints
 from policyengine_api.routes.economy_routes import economy_bp
+from policyengine_api.routes.simulation_analysis_routes import simulation_analysis_bp
 from .endpoints import (
     get_home,
     get_metadata,
@@ -25,7 +26,6 @@
     get_policy_search,
     get_household_under_policy,
     get_calculate,
-    execute_simulation_analysis,
     set_user_policy,
     get_user_policy,
     update_user_policy,
@@ -90,11 +90,14 @@
     )
 )

+# Routes for economy microsimulation
 app.register_blueprint(economy_bp, url_prefix="/<country_id>/economy")

-app.route("/<country_id>/simulation-analysis", methods=["POST"])(
-    execute_simulation_analysis
-)
+# Routes for AI analysis of economy microsim runs
+app.register_blueprint(
+    simulation_analysis_bp,
+    url_prefix="/<country_id>/simulation-analysis"
+    )

 app.route("/<country_id>/user-policy", methods=["POST"])(set_user_policy)

diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py
new file mode 100644
index 00000000..a63f4ec5
--- /dev/null
+++ b/policyengine_api/routes/simulation_analysis_routes.py
@@ -0,0 +1,8 @@
+from flask import Blueprint
+
+simulation_analysis_bp = Blueprint("simulation_analysis", __name__)
+from policyengine_api.endpoints.simulation_analysis import execute_simulation_analysis
+
+@simulation_analysis_bp.route("/", methods=["POST"])
+def execute_simulation_analysis_placeholder(country_id):
+    return execute_simulation_analysis(country_id)

From 0ccd5a8f6bebe50d07fa5d5ac157f647bd334a45 Mon Sep 17 00:00:00 2001
From: Anthony Volk
Date: Thu, 21 Nov 2024 21:08:06 +0100
Subject: [PATCH 02/14] feat: Add new services, add comments where relevant

---
 policyengine_api/ai_prompts/simulation.py      | 128 ------------
 .../endpoints/simulation_analysis.py           |  68 -------
 .../routes/simulation_analysis_routes.py       |  86 +++++++-
 policyengine_api/services/analysis_service.py  |  87 ++++++++
 policyengine_api/services/economy_service.py   |   5 +
 policyengine_api/services/job_service.py       |   5 +
 policyengine_api/services/policy_service.py    |   5 +
 .../services/reform_impacts_service.py         |   6 +
 .../services/simulation_analysis_service.py    | 192 ++++++++++++++++++
 9 files changed, 383 insertions(+), 199 deletions(-)
 delete mode 100644 policyengine_api/ai_prompts/simulation.py
 delete mode 100644 policyengine_api/endpoints/simulation_analysis.py
 create mode 100644 policyengine_api/services/analysis_service.py
 create mode 100644 policyengine_api/services/simulation_analysis_service.py

diff --git a/policyengine_api/ai_prompts/simulation.py b/policyengine_api/ai_prompts/simulation.py
deleted file mode 100644
index fd9562b5..00000000
--- a/policyengine_api/ai_prompts/simulation.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import json
-
-audience_descriptions = {
-    "ELI5": "Write this for a layperson who doesn't know much about economics or policy. 
Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", - "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", - "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.", -} - - -def generate_simulation_analysis_prompt( - time_period, - region, - currency, - policy, - impact, - relevant_parameters, - relevant_parameter_baseline_values, - is_enhanced_cps, - selected_version, - country_id, - policy_label, -): - return f""" - I'm using PolicyEngine, a free, open source tool to compute the impact of - public policy. I'm writing up an economic analysis of a hypothetical tax-benefit - policy reform. Please write the analysis for me using the details below, in - their order. You should: - - * First explain each provision of the reform, noting that it's hypothetical and - won't represents policy reforms for {time_period} and {region}. Explain how - the parameters are changing from the baseline to the reform values using the given data. - - {'''Explicitly mention that this analysis uses PolicyEngine Enhanced CPS, constructed - from the 2022 Current Population Survey and the 2015 IRS Public Use File, and calibrated - to tax, benefit, income, and demographic aggregates.''' if is_enhanced_cps else ''} - - * Round large numbers like: {currency}3.1 billion, {currency}300 million, - {currency}106,000, {currency}1.50 (never {currency}1.5). - - * Round percentages to one decimal place. - - * Avoid normative language like 'requires', 'should', 'must', and use quantitative statements - over general adjectives and adverbs. If you don't know what something is, don't make it up. - - * Avoid speculating about the intent of the policy or inferring any motives; only describe the - observable effects and impacts of the policy. Refrain from using subjective language or making - assumptions about the recipients and their needs. - - * Use the active voice where possible; for example, write phrases where the reform is the subject, - such as "the reform [or a description of the reform] reduces poverty by x%". - - * Use {'British' if country_id == 'uk' else 'American'} English spelling and grammar. - - * Cite PolicyEngine {country_id.upper()} v{selected_version} and the { - 'PolicyEngine-enhanced 2019 Family Resources Survey' if country_id == 'uk' else '2022 Current Population Survey March Supplement' - } microdata when describing policy impacts. - - * When describing poverty impacts, note that the poverty measure reported is { - 'absolute poverty before housing costs' if country_id == 'uk' else 'the Supplemental Poverty Measure' - } - - * Don't use headers, but do use Markdown formatting. Use - for bullets, and include a newline after each bullet. - - * Include the following embeds inline, without a header so it flows. - - * Immediately after you describe the changes by decile, include the text: '{{distributionalImpact.incomeDecile.relative}}' - - * And after the poverty rate changes, include the text: '{{povertyImpact.regular.byAge}}' - - {f"* After the racial breakdown of poverty rate changes, include the text: '{{povertyImpact.regular.byRace}}'" if country_id == 'us' else ''} - - * And after the inequality changes, include the text: '{{inequalityImpact}}' - - * Make sure to accurately represent the changes observed in the data. 
- - This JSON snippet describes the default parameter values: {json.dumps( - relevant_parameter_baseline_values, - )}\n - - This JSON snippet describes the baseline and reform policies being compared: {json.dumps( - policy, - )}\n`; - - {policy_label} has the following impacts from the PolicyEngine microsimulation model: - - This JSON snippet describes the relevant parameters with more details: {json.dumps( - relevant_parameters, - )} - - This JSON describes the total budgetary impact, the change to tax revenues and benefit - spending (ignore 'households' and 'baseline_net_income': {json.dumps( - impact["budget"], - )} - - This JSON describes how common different outcomes were at each income decile: {json.dumps( - impact["intra_decile"], - )} - - This JSON describes the average and relative changes to income by each income decile: {json.dumps( - impact["decile"], - )} - - This JSON describes the baseline and reform poverty rates by age group (describe the relative changes): {json.dumps( - impact["poverty"]["poverty"], - )} - - This JSON describes the baseline and reform deep poverty rates by age group - (describe the relative changes): {json.dumps( - impact["poverty"]["deep_poverty"], - )} - - This JSON describes the baseline and reform poverty and deep poverty rates - by gender (briefly describe the relative changes): {json.dumps( - impact["poverty_by_gender"], - )} - - { - '''This JSON describes the baseline and reform poverty impacts by racial group (briefly - describe the relative changes): {json.dumps(impact["poverty_by_race"]["poverty"])}''' - if country_id == "us" else "" - } - - This JSON describes three inequality metrics in the baseline and reform, the Gini - coefficient of income inequality, the share of income held by the top 10% of households - and the share held by the top 1% (describe the relative changes): {json.dumps( - impact["inequality"], - )} - """ diff --git a/policyengine_api/endpoints/simulation_analysis.py b/policyengine_api/endpoints/simulation_analysis.py deleted file mode 100644 index 1b5c27a3..00000000 --- a/policyengine_api/endpoints/simulation_analysis.py +++ /dev/null @@ -1,68 +0,0 @@ -from flask import request, Response -from policyengine_api.utils.ai_analysis import ( - trigger_ai_analysis, - get_existing_analysis, -) -from policyengine_api.ai_prompts import ( - generate_simulation_analysis_prompt, - audience_descriptions, -) - - -def execute_simulation_analysis(country_id: str) -> Response: - - # Pop the various parameters from the request - payload = request.json - - currency = payload.get("currency") - selected_version = payload.get("selected_version") - time_period = payload.get("time_period") - impact = payload.get("impact") - policy_label = payload.get("policy_label") - policy = payload.get("policy") - region = payload.get("region") - relevant_parameters = payload.get("relevant_parameters") - relevant_parameter_baseline_values = payload.get( - "relevant_parameter_baseline_values" - ) - audience = payload.get("audience") - - # Check if the region is enhanced_cps - is_enhanced_cps = "enhanced_cps" in region - - # Create prompt based on data - prompt = generate_simulation_analysis_prompt( - time_period, - region, - currency, - policy, - impact, - relevant_parameters, - relevant_parameter_baseline_values, - is_enhanced_cps, - selected_version, - country_id, - policy_label, - ) - - # Add audience description to end - prompt += audience_descriptions[audience] - - # If a calculated record exists for this prompt, return it as a - # streaming response - 
existing_analysis = get_existing_analysis(prompt) - if existing_analysis is not None: - return Response(status=200, response=existing_analysis) - - # Otherwise, pass prompt to Claude, then return streaming function - try: - analysis = trigger_ai_analysis(prompt) - return Response(status=200, response=analysis) - except Exception as e: - return Response( - status=500, - response={ - "message": "Error computing analysis", - "error": str(e), - }, - ) diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index a63f4ec5..902d7827 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,8 +1,88 @@ -from flask import Blueprint +from flask import Blueprint, request, Response +from policyengine_api.helpers import validate_country +from policyengine_api.services.simulation_analysis_service import SimulationAnalysisService simulation_analysis_bp = Blueprint("simulation_analysis", __name__) -from policyengine_api.endpoints.simulation_analysis import execute_simulation_analysis +simulation_analysis_servce = SimulationAnalysisService() @simulation_analysis_bp.route("/", methods=["POST"]) def execute_simulation_analysis_placeholder(country_id): - return execute_simulation_analysis(country_id) + print("Got POST request for simulation analysis") + + # Validate inbound country ID + invalid_country = validate_country(country_id) + if invalid_country: + return invalid_country + + # Pop items from request payload and validate + # where necessary + payload = request.json + + is_payload_valid, message = validate_payload(payload) + if not is_payload_valid: + return Response(status=400, response=f"Invalid JSON data; details: {message}") + + currency = payload.get("currency") + selected_version = payload.get("selected_version") + time_period = payload.get("time_period") + impact = payload.get("impact") + policy_label = payload.get("policy_label") + policy = payload.get("policy") + region = payload.get("region") + relevant_parameters = payload.get("relevant_parameters") + relevant_parameter_baseline_values = payload.get( + "relevant_parameter_baseline_values" + ) + audience = payload.get("audience", "") + + try: + analysis = simulation_analysis_servce.execute_simulation_analysis( + country_id, + currency, + selected_version, + time_period, + impact, + policy_label, + policy, + region, + relevant_parameters, + relevant_parameter_baseline_values, + audience, + ) + + return Response(status=200, response=analysis) + except Exception as e: + return ( + dict( + status="error", + message="An error occurred while executing the simulation analysis. 
Details: " + + str(e), + result=None, + ), + 500, + ) + +def validate_payload(payload: dict): + # Check if all required keys are present; note + # that the audience key is optional + required_keys = [ + "currency", + "selected_version", + "time_period", + "impact", + "policy_label", + "policy", + "region", + "relevant_parameters", + "relevant_parameter_baseline_values", + ] + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + return False, f"Missing required keys: {missing_keys}" + + # Check if all keys are of the right type + for key, value in payload.items(): + if not isinstance(value, str): + return False, f"Key '{key}' must be a string" + + return True, None \ No newline at end of file diff --git a/policyengine_api/services/analysis_service.py b/policyengine_api/services/analysis_service.py new file mode 100644 index 00000000..14714da7 --- /dev/null +++ b/policyengine_api/services/analysis_service.py @@ -0,0 +1,87 @@ +import anthropic +import os +import json +import time +from typing import Generator +from policyengine_api.data import local_database + +class AIAnalysisService: + """ + Base class for various AI analysis-based services, + including SimulationAnalysisService, that connects with the analysis + local database table + """ + + def get_existing_analysis(prompt: str) -> Generator[str, None, None] | None: + """ + Get existing analysis from the local database + """ + + analysis = local_database.query( + f"SELECT analysis FROM analysis WHERE prompt = ?", + (prompt,), + ).fetchone() + + if analysis is None: + return None + + def generate(): + + # First, yield prompt so it's accessible on front end + initial_data = { + "stream": "", + "prompt": prompt, + } + yield json.dumps(initial_data) + "\n" + + chunk_size = 5 + for i in range(0, len(analysis["analysis"]), chunk_size): + chunk = analysis["analysis"][i : i + chunk_size] + yield json.dumps({"stream": chunk}) + "\n" + time.sleep(0.05) + + return generate() + + def trigger_ai_analysis(prompt: str) -> Generator[str, None, None]: + + # Configure a Claude client + claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + + def generate(): + chunk_size = 5 + response_text = "" + buffer = "" + + # First, yield prompt so it's accessible on front end + initial_data = { + "stream": "", + "prompt": prompt, + } + + yield json.dumps(initial_data) + "\n" + + with claude_client.messages.stream( + model="claude-3-5-sonnet-20240620", + max_tokens=1500, + temperature=0.0, + system="Respond with a historical quote", + messages=[{"role": "user", "content": prompt}], + ) as stream: + for item in stream.text_stream: + buffer += item + response_text += item + while len(buffer) >= chunk_size: + chunk = buffer[:chunk_size] + buffer = buffer[chunk_size:] + yield json.dumps({"stream": chunk}) + "\n" + + if buffer: + yield json.dumps({"stream": buffer}) + "\n" + + # Finally, update the analysis record and return + local_database.query( + f"INSERT INTO analysis (prompt, analysis, status) VALUES (?, ?, ?)", + (prompt, response_text, "ok"), + ) + + return generate() \ No newline at end of file diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 0f13bfa4..da7bd34c 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -13,6 +13,11 @@ class EconomyService: + """ + Service for calculating economic impact of policy reforms; this is connected + to the /economy route, which does not have its own 
table; therefore, it connects + with other services to access their respective tables + """ def get_economic_impact( self, country_id, diff --git a/policyengine_api/services/job_service.py b/policyengine_api/services/job_service.py index 97b39870..f8d61c69 100644 --- a/policyengine_api/services/job_service.py +++ b/policyengine_api/services/job_service.py @@ -19,6 +19,11 @@ class JobStatus(Enum): class JobService(metaclass=Singleton): + """ + Hybrid service used to manage backend economy-wide simulation + jobs. This is not connected to any routes or tables, but interfaces + with the Redis queue to enqueue jobs and track their status. + """ def __init__(self): self.recent_jobs = {} diff --git a/policyengine_api/services/policy_service.py b/policyengine_api/services/policy_service.py index e2f2e813..4367ff25 100644 --- a/policyengine_api/services/policy_service.py +++ b/policyengine_api/services/policy_service.py @@ -2,6 +2,11 @@ class PolicyService: + """ + Partially-implemented service for storing and retrieving policies; + this will be connected to the /policy route and is partially connected + to the policy database table + """ def get_policy_json(self, country_id, policy_id): try: policy_json = database.query( diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 11524ff0..eb07fc4e 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -3,6 +3,12 @@ class ReformImpactsService: + """ + Service for storing and retrieving economy-wide reform impacts; + this is connected to the locally-stored reform_impact table + and no existing route + """ + def get_all_reform_impacts( self, country_id, diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py new file mode 100644 index 00000000..dc6444b8 --- /dev/null +++ b/policyengine_api/services/simulation_analysis_service.py @@ -0,0 +1,192 @@ +import json + +from policyengine_api.services.analysis_service import AnalysisService +from policyengine_api.data import local_database + +class SimulationAnalysisService(AnalysisService): + """ + Service for generating AI analysis of economy-wide simulation + runs; this is connected with the simulation_analysis route and + analysis database table + """ + def __init__(self): + super().__init__() + + def execute_simulation_analysis( + self, + country_id: str, + currency: str, + selected_version: str, + time_period: str, + impact: str, + policy_label: str, + policy: str, + region: str, + relevant_parameters: str, + relevant_parameter_baseline_values: str, + audience: str | None, + ): + + # Check if the region is enhanced_cps + is_enhanced_cps = "enhanced_cps" in region + + print("Generating prompt for economy-wide simulation analysis") + + # Create prompt based on data + prompt = self._generate_simulation_analysis_prompt( + time_period, + region, + currency, + policy, + impact, + relevant_parameters, + relevant_parameter_baseline_values, + is_enhanced_cps, + selected_version, + country_id, + policy_label, + ) + + # Add audience description to end + prompt += self.audience_descriptions[audience] + + print("Checking if AI analysis already exists for this prompt") + # If a calculated record exists for this prompt, return it as a + # streaming response + existing_analysis = self.get_existing_analysis(prompt) + if existing_analysis is not None: + return existing_analysis + + print("Found no existing AI 
analysis; triggering new analysis with Claude") + # Otherwise, pass prompt to Claude, then return streaming function + try: + analysis = self.trigger_ai_analysis(prompt) + return analysis + except Exception as e: + raise e + + def _generate_simulation_analysis_prompt( + time_period, + region, + currency, + policy, + impact, + relevant_parameters, + relevant_parameter_baseline_values, + is_enhanced_cps, + selected_version, + country_id, + policy_label, + ): + return f""" + I'm using PolicyEngine, a free, open source tool to compute the impact of + public policy. I'm writing up an economic analysis of a hypothetical tax-benefit + policy reform. Please write the analysis for me using the details below, in + their order. You should: + + * First explain each provision of the reform, noting that it's hypothetical and + won't represents policy reforms for {time_period} and {region}. Explain how + the parameters are changing from the baseline to the reform values using the given data. + + {'''Explicitly mention that this analysis uses PolicyEngine Enhanced CPS, constructed + from the 2022 Current Population Survey and the 2015 IRS Public Use File, and calibrated + to tax, benefit, income, and demographic aggregates.''' if is_enhanced_cps else ''} + + * Round large numbers like: {currency}3.1 billion, {currency}300 million, + {currency}106,000, {currency}1.50 (never {currency}1.5). + + * Round percentages to one decimal place. + + * Avoid normative language like 'requires', 'should', 'must', and use quantitative statements + over general adjectives and adverbs. If you don't know what something is, don't make it up. + + * Avoid speculating about the intent of the policy or inferring any motives; only describe the + observable effects and impacts of the policy. Refrain from using subjective language or making + assumptions about the recipients and their needs. + + * Use the active voice where possible; for example, write phrases where the reform is the subject, + such as "the reform [or a description of the reform] reduces poverty by x%". + + * Use {'British' if country_id == 'uk' else 'American'} English spelling and grammar. + + * Cite PolicyEngine {country_id.upper()} v{selected_version} and the { + 'PolicyEngine-enhanced 2019 Family Resources Survey' if country_id == 'uk' else '2022 Current Population Survey March Supplement' + } microdata when describing policy impacts. + + * When describing poverty impacts, note that the poverty measure reported is { + 'absolute poverty before housing costs' if country_id == 'uk' else 'the Supplemental Poverty Measure' + } + + * Don't use headers, but do use Markdown formatting. Use - for bullets, and include a newline after each bullet. + + * Include the following embeds inline, without a header so it flows. + + * Immediately after you describe the changes by decile, include the text: '{{distributionalImpact.incomeDecile.relative}}' + + * And after the poverty rate changes, include the text: '{{povertyImpact.regular.byAge}}' + + {f"* After the racial breakdown of poverty rate changes, include the text: '{{povertyImpact.regular.byRace}}'" if country_id == 'us' else ''} + + * And after the inequality changes, include the text: '{{inequalityImpact}}' + + * Make sure to accurately represent the changes observed in the data. 
+ + This JSON snippet describes the default parameter values: {json.dumps( + relevant_parameter_baseline_values, + )}\n + + This JSON snippet describes the baseline and reform policies being compared: {json.dumps( + policy, + )}\n`; + + {policy_label} has the following impacts from the PolicyEngine microsimulation model: + + This JSON snippet describes the relevant parameters with more details: {json.dumps( + relevant_parameters, + )} + + This JSON describes the total budgetary impact, the change to tax revenues and benefit + spending (ignore 'households' and 'baseline_net_income': {json.dumps( + impact["budget"], + )} + + This JSON describes how common different outcomes were at each income decile: {json.dumps( + impact["intra_decile"], + )} + + This JSON describes the average and relative changes to income by each income decile: {json.dumps( + impact["decile"], + )} + + This JSON describes the baseline and reform poverty rates by age group (describe the relative changes): {json.dumps( + impact["poverty"]["poverty"], + )} + + This JSON describes the baseline and reform deep poverty rates by age group + (describe the relative changes): {json.dumps( + impact["poverty"]["deep_poverty"], + )} + + This JSON describes the baseline and reform poverty and deep poverty rates + by gender (briefly describe the relative changes): {json.dumps( + impact["poverty_by_gender"], + )} + + { + '''This JSON describes the baseline and reform poverty impacts by racial group (briefly + describe the relative changes): {json.dumps(impact["poverty_by_race"]["poverty"])}''' + if country_id == "us" else "" + } + + This JSON describes three inequality metrics in the baseline and reform, the Gini + coefficient of income inequality, the share of income held by the top 10% of households + and the share held by the top 1% (describe the relative changes): {json.dumps( + impact["inequality"], + )} + """ + + audience_descriptions = { + "ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", + "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", + "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. 
Use acronyms and jargon if it makes the content more concise and informative.",
+    }
\ No newline at end of file

From 9490cbda1297e794b45a9c7c70344bc96b2dd24c Mon Sep 17 00:00:00 2001
From: Anthony Volk
Date: Thu, 21 Nov 2024 21:35:49 +0100
Subject: [PATCH 03/14] feat: TracerAnalysisService

---
 policyengine_api/ai_prompts/__init__.py       |   5 -
 policyengine_api/ai_prompts/tracer.py         |  16 ---
 policyengine_api/api.py                       |   7 +-
 policyengine_api/endpoints/__init__.py        |   4 -
 policyengine_api/endpoints/tracer_analysis.py |   1 +
 policyengine_api/routes/economy_routes.py     |   4 -
 .../routes/simulation_analysis_routes.py      |   6 +-
 .../routes/tracer_analysis_routes.py          |  56 ++++++++
 .../services/simulation_analysis_service.py   |   3 +-
 .../services/tracer_analysis_service.py       | 130 ++++++++++++++++++
 policyengine_api/utils/ai_analysis.py         |  82 -----------
 policyengine_api/utils/tracer_analysis.py     |  31 -----
 12 files changed, 193 insertions(+), 152 deletions(-)
 create mode 100644 policyengine_api/routes/tracer_analysis_routes.py
 create mode 100644 policyengine_api/services/tracer_analysis_service.py
 delete mode 100644 policyengine_api/ai_prompts/__init__.py
 delete mode 100644 policyengine_api/ai_prompts/tracer.py
 delete mode 100644 policyengine_api/utils/ai_analysis.py
 delete mode 100644 policyengine_api/utils/tracer_analysis.py

diff --git a/policyengine_api/ai_prompts/__init__.py b/policyengine_api/ai_prompts/__init__.py
deleted file mode 100644
index a4219d0a..00000000
--- a/policyengine_api/ai_prompts/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from policyengine_api.ai_prompts.tracer import tracer_analysis_prompt
-from policyengine_api.ai_prompts.simulation import (
-    generate_simulation_analysis_prompt,
-    audience_descriptions,
-)
diff --git a/policyengine_api/ai_prompts/tracer.py b/policyengine_api/ai_prompts/tracer.py
deleted file mode 100644
index 7f87f3a2..00000000
--- a/policyengine_api/ai_prompts/tracer.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import anthropic
-
-tracer_analysis_prompt = f"""{anthropic.HUMAN_PROMPT} You are an AI assistant explaining US policy calculations.
-The user has run a simulation for the variable '{{variable}}'.
-Here's the tracer output:
-{{tracer_segment}}
-
-Please explain this result in simple terms. Your explanation should:
-1. Briefly describe what {{variable}} is.
-2. Explain the main factors that led to this result.
-3. Mention any key thresholds or rules that affected the calculation.
-4. If relevant, suggest how changes in input might affect this result.
-
-Keep your explanation concise but informative, suitable for a general audience. Do not start with phrases like "Certainly!" or "Here's an explanation. It will be rendered as markdown, so preface $ with \. 
- -{anthropic.AI_PROMPT}""" diff --git a/policyengine_api/api.py b/policyengine_api/api.py index e68eb2e8..162fa00d 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -15,6 +15,7 @@ # Endpoints from policyengine_api.routes.economy_routes import economy_bp from policyengine_api.routes.simulation_analysis_routes import simulation_analysis_bp +from policyengine_api.routes.tracer_analysis_routes import tracer_analysis_bp from .endpoints import ( get_home, get_metadata, @@ -33,7 +34,6 @@ get_user_profile, update_user_profile, get_simulations, - execute_tracer_analysis, ) print("Initialising API...") @@ -115,10 +115,7 @@ app.route("/simulations", methods=["GET"])(get_simulations) -app.route("//tracer-analysis", methods=["POST"])( - execute_tracer_analysis -) - +app.register_blueprint(tracer_analysis_bp, url_prefix="//tracer-analysis") @app.route("/liveness-check", methods=["GET"]) def liveness_check(): diff --git a/policyengine_api/endpoints/__init__.py b/policyengine_api/endpoints/__init__.py index 44a6c7eb..cc42c81b 100644 --- a/policyengine_api/endpoints/__init__.py +++ b/policyengine_api/endpoints/__init__.py @@ -16,13 +16,9 @@ update_user_policy, ) -# from .economy import get_economic_impact -from .simulation_analysis import execute_simulation_analysis - from .user_profile import ( set_user_profile, get_user_profile, update_user_profile, ) from .simulation import get_simulations -from .tracer_analysis import execute_tracer_analysis diff --git a/policyengine_api/endpoints/tracer_analysis.py b/policyengine_api/endpoints/tracer_analysis.py index d5e178ae..68085888 100644 --- a/policyengine_api/endpoints/tracer_analysis.py +++ b/policyengine_api/endpoints/tracer_analysis.py @@ -92,3 +92,4 @@ def execute_tracer_analysis( "error": str(e), }, ) + diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index b6051b0b..7b800d19 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -52,7 +52,3 @@ def get_economic_impact(country_id, policy_id, baseline_policy_id): ), 500, ) - - # Run service to check if already calculated in local db - - # Service to diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 902d7827..e8442f23 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -3,10 +3,10 @@ from policyengine_api.services.simulation_analysis_service import SimulationAnalysisService simulation_analysis_bp = Blueprint("simulation_analysis", __name__) -simulation_analysis_servce = SimulationAnalysisService() +simulation_analysis_service = SimulationAnalysisService() @simulation_analysis_bp.route("/", methods=["POST"]) -def execute_simulation_analysis_placeholder(country_id): +def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") # Validate inbound country ID @@ -36,7 +36,7 @@ def execute_simulation_analysis_placeholder(country_id): audience = payload.get("audience", "") try: - analysis = simulation_analysis_servce.execute_simulation_analysis( + analysis = simulation_analysis_service.execute_analysis( country_id, currency, selected_version, diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py new file mode 100644 index 00000000..0c831efc --- /dev/null +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -0,0 +1,56 @@ +from flask import 
Blueprint, request, Response +from policyengine_api.helpers import validate_country +from policyengine_api.services.tracer_analysis_service import TracerAnalysisService + +tracer_analysis_bp = Blueprint("tracer_analysis", __name__) +tracer_analysis_service = TracerAnalysisService() + +@tracer_analysis_bp.route("/", methods=["POST"]) +def execute_tracer_analysis(country_id): + + # Validate country ID + country_not_found = validate_country(country_id) + if country_not_found: + return country_not_found + + payload = request.json + + is_payload_valid, message = validate_payload(payload) + if not is_payload_valid: + return Response(status=400, response=f"Invalid JSON data; details: {message}") + + household_id = payload.get("household_id") + policy_id = payload.get("policy_id") + variable = payload.get("variable") + + try: + analysis = tracer_analysis_service.execute_analysis( + country_id, + household_id, + policy_id, + variable, + ) + + return Response(status=200, response=analysis) + except Exception as e: + return ( + dict( + status="error", + message="An error occurred while executing the tracer analysis. Details: " + + str(e), + result=None, + ), + 500, + ) + +def validate_payload(payload: dict): + # Validate payload + if not payload: + return False, "No payload provided" + + required_keys = ["household_id", "policy_id", "variable"] + for key in required_keys: + if key not in payload: + return False, f"Missing required key: {key}" + + return True, None \ No newline at end of file diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index dc6444b8..da5352a3 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -1,7 +1,6 @@ import json from policyengine_api.services.analysis_service import AnalysisService -from policyengine_api.data import local_database class SimulationAnalysisService(AnalysisService): """ @@ -12,7 +11,7 @@ class SimulationAnalysisService(AnalysisService): def __init__(self): super().__init__() - def execute_simulation_analysis( + def execute_analysis( self, country_id: str, currency: str, diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py new file mode 100644 index 00000000..f6eb6d22 --- /dev/null +++ b/policyengine_api/services/tracer_analysis_service.py @@ -0,0 +1,130 @@ +from policyengine_api.data import local_database +import json +from flask import stream_with_context +from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS +from typing import Generator +import re +import anthropic +from policyengine_api.services.analysis_service import AnalysisService + +class TracerAnalysisService(AnalysisService): + def __init__(self): + super().__init__() + + def execute_analysis( + self, + country_id: str, + household_id: str, + policy_id: str, + variable: str, + ): + + api_version = COUNTRY_PACKAGE_VERSIONS[country_id] + + # Retrieve tracer record from table + try: + tracer: list[str] = self.get_tracer( + country_id, + household_id, + policy_id, + api_version, + ) + except Exception as e: + raise e + + # Parse the tracer output for our given variable + try: + tracer_segment: list[str] = self._parse_tracer_output(tracer, variable) + except Exception as e: + print(f"Error parsing tracer output: {str(e)}") + raise e + + # Add the parsed tracer output to the prompt + prompt = self.prompt_template.format( + variable=variable, 
tracer_segment=tracer_segment + ) + + # If a calculated record exists for this prompt, return it as a + # streaming response + existing_analysis: Generator = self.get_existing_analysis(prompt) + if existing_analysis is not None: + return stream_with_context(existing_analysis) + + # Otherwise, pass prompt to Claude, then return streaming function + try: + analysis: Generator = self.trigger_ai_analysis(prompt) + return stream_with_context(analysis) + except Exception as e: + print(f"Error generating AI analysis within tracer analysis service: {str(e)}") + raise e + + def get_tracer( + self, + country_id: str, + household_id: str, + policy_id: str, + api_version: str, + ) -> list: + try: + # Retrieve from the tracers table in the local database + row = local_database.query( + """ + SELECT * FROM tracers + WHERE household_id = ? AND policy_id = ? AND country_id = ? AND api_version = ? + """, + (household_id, policy_id, country_id, api_version), + ).fetchone() + + if row is None: + raise KeyError("No tracer found for this household") + + tracer_output_list = json.loads(row["tracer_output"]) + return tracer_output_list + + except Exception as e: + print(f"Error getting existing tracer analysis: {str(e)}") + raise e + + def _parse_tracer_output(self, tracer_output, target_variable): + result = [] + target_indent = None + capturing = False + + # Create a regex pattern to match the exact variable name + # This will match the variable name followed by optional whitespace, + # then optional angle brackets with any content, then optional whitespace + pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*" + + for line in tracer_output: + # Count leading spaces to determine indentation level + indent = len(line) - len(line.strip()) + + # Check if this line matches our target variable + match = re.match(pattern, line) + if match and not capturing: + target_indent = indent + capturing = True + result.append(line) + elif capturing: + # Stop capturing if we encounter a line with less indentation than the target + if indent <= target_indent: + break + # Capture dependencies (lines with greater indentation) + result.append(line) + + return result + + prompt_template = f"""{anthropic.HUMAN_PROMPT} You are an AI assistant explaining US policy calculations. + The user has run a simulation for the variable '{{variable}}'. + Here's the tracer output: + {{tracer_segment}} + + Please explain this result in simple terms. Your explanation should: + 1. Briefly describe what {{variable}} is. + 2. Explain the main factors that led to this result. + 3. Mention any key thresholds or rules that affected the calculation. + 4. If relevant, suggest how changes in input might affect this result. + + Keep your explanation concise but informative, suitable for a general audience. Do not start with phrases like "Certainly!" or "Here's an explanation. It will be rendered as markdown, so preface $ with \. 
+ + {anthropic.AI_PROMPT}""" \ No newline at end of file diff --git a/policyengine_api/utils/ai_analysis.py b/policyengine_api/utils/ai_analysis.py deleted file mode 100644 index 2c329838..00000000 --- a/policyengine_api/utils/ai_analysis.py +++ /dev/null @@ -1,82 +0,0 @@ -import anthropic -import os -import time -from policyengine_api.data import local_database -from typing import Generator -import json - - -def trigger_ai_analysis(prompt: str) -> Generator[str, None, None]: - - # Configure a Claude client - claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) - - def generate(): - chunk_size = 5 - response_text = "" - buffer = "" - - # First, yield prompt so it's accessible on front end - initial_data = { - "stream": "", - "prompt": prompt, - } - - yield json.dumps(initial_data) + "\n" - - with claude_client.messages.stream( - model="claude-3-5-sonnet-20240620", - max_tokens=1500, - temperature=0.0, - system="Respond with a historical quote", - messages=[{"role": "user", "content": prompt}], - ) as stream: - for item in stream.text_stream: - buffer += item - response_text += item - while len(buffer) >= chunk_size: - chunk = buffer[:chunk_size] - buffer = buffer[chunk_size:] - yield json.dumps({"stream": chunk}) + "\n" - - if buffer: - yield json.dumps({"stream": buffer}) + "\n" - - # Finally, update the analysis record and return - local_database.query( - f"INSERT INTO analysis (prompt, analysis, status) VALUES (?, ?, ?)", - (prompt, response_text, "ok"), - ) - - return generate() - - -def get_existing_analysis(prompt: str) -> Generator[str, None, None] | None: - """ - Get existing analysis from the local database - """ - - analysis = local_database.query( - f"SELECT analysis FROM analysis WHERE prompt = ?", - (prompt,), - ).fetchone() - - if analysis is None: - return None - - def generate(): - - # First, yield prompt so it's accessible on front end - initial_data = { - "stream": "", - "prompt": prompt, - } - yield json.dumps(initial_data) + "\n" - - chunk_size = 5 - for i in range(0, len(analysis["analysis"]), chunk_size): - chunk = analysis["analysis"][i : i + chunk_size] - yield json.dumps({"stream": chunk}) + "\n" - time.sleep(0.05) - - return generate() diff --git a/policyengine_api/utils/tracer_analysis.py b/policyengine_api/utils/tracer_analysis.py deleted file mode 100644 index a9e595d4..00000000 --- a/policyengine_api/utils/tracer_analysis.py +++ /dev/null @@ -1,31 +0,0 @@ -import re - - -def parse_tracer_output(tracer_output, target_variable): - result = [] - target_indent = None - capturing = False - - # Create a regex pattern to match the exact variable name - # This will match the variable name followed by optional whitespace, - # then optional angle brackets with any content, then optional whitespace - pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*" - - for line in tracer_output: - # Count leading spaces to determine indentation level - indent = len(line) - len(line.strip()) - - # Check if this line matches our target variable - match = re.match(pattern, line) - if match and not capturing: - target_indent = indent - capturing = True - result.append(line) - elif capturing: - # Stop capturing if we encounter a line with less indentation than the target - if indent <= target_indent: - break - # Capture dependencies (lines with greater indentation) - result.append(line) - - return result From 6386560e5872b4e67512bce12e413447d0e46804 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 Nov 2024 23:51:31 +0100 Subject: [PATCH 
04/14] feat: Correct streaming for Google App Engine --- .../routes/simulation_analysis_routes.py | 68 ++++++++++++------- .../routes/tracer_analysis_routes.py | 41 ++++++----- ...ysis_service.py => ai_analysis_service.py} | 4 +- .../services/simulation_analysis_service.py | 13 ++-- .../services/tracer_analysis_service.py | 9 ++- 5 files changed, 82 insertions(+), 53 deletions(-) rename policyengine_api/services/{analysis_service.py => ai_analysis_service.py} (94%) diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index e8442f23..da7dfe77 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,11 +1,11 @@ -from flask import Blueprint, request, Response +from flask import Blueprint, request, Response, stream_with_context from policyengine_api.helpers import validate_country from policyengine_api.services.simulation_analysis_service import SimulationAnalysisService simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() -@simulation_analysis_bp.route("/", methods=["POST"]) +@simulation_analysis_bp.route("", methods=["POST"]) def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") @@ -22,15 +22,15 @@ def execute_simulation_analysis(country_id): if not is_payload_valid: return Response(status=400, response=f"Invalid JSON data; details: {message}") - currency = payload.get("currency") - selected_version = payload.get("selected_version") - time_period = payload.get("time_period") - impact = payload.get("impact") - policy_label = payload.get("policy_label") - policy = payload.get("policy") - region = payload.get("region") - relevant_parameters = payload.get("relevant_parameters") - relevant_parameter_baseline_values = payload.get( + currency: str = payload.get("currency") + selected_version: str = payload.get("selected_version") + time_period: str = payload.get("time_period") + impact: dict = payload.get("impact") + policy_label: str = payload.get("policy_label") + policy: dict = payload.get("policy") + region: str = payload.get("region") + relevant_parameters: list = payload.get("relevant_parameters") + relevant_parameter_baseline_values: list = payload.get( "relevant_parameter_baseline_values" ) audience = payload.get("audience", "") @@ -50,18 +50,24 @@ def execute_simulation_analysis(country_id): audience, ) - return Response(status=200, response=analysis) - except Exception as e: - return ( - dict( - status="error", - message="An error occurred while executing the simulation analysis. Details: " - + str(e), - result=None, - ), - 500, + # Create streaming response + response = Response( + stream_with_context(analysis), + status=200, ) + # Set header to prevent buffering on Google App Engine deployment + # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) + response.headers['X-Accel-Buffering'] = 'no' + + return response + except Exception as e: + return { + "status": "error", + "message": "An error occurred while executing the simulation analysis. 
Details: " + str(e), + "result": None, + }, 500 + def validate_payload(payload: dict): # Check if all required keys are present; note # that the audience key is optional @@ -76,13 +82,29 @@ def validate_payload(payload: dict): "relevant_parameters", "relevant_parameter_baseline_values", ] + str_keys = [ + "currency", + "selected_version", + "time_period", + "policy_label", + "region", + ] + dict_keys = [ + "policy", + "impact", + ] + list_keys = ["relevant_parameters", "relevant_parameter_baseline_values"] missing_keys = [key for key in required_keys if key not in payload] if missing_keys: return False, f"Missing required keys: {missing_keys}" - + # Check if all keys are of the right type for key, value in payload.items(): - if not isinstance(value, str): + if key in str_keys and not isinstance(value, str): return False, f"Key '{key}' must be a string" + elif key in dict_keys and not isinstance(value, dict): + return False, f"Key '{key}' must be a dictionary" + elif key in list_keys and not isinstance(value, list): + return False, f"Key '{key}' must be a list" return True, None \ No newline at end of file diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 0c831efc..2fcdae9a 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,11 +1,11 @@ -from flask import Blueprint, request, Response +from flask import Blueprint, request, Response, stream_with_context from policyengine_api.helpers import validate_country from policyengine_api.services.tracer_analysis_service import TracerAnalysisService tracer_analysis_bp = Blueprint("tracer_analysis", __name__) tracer_analysis_service = TracerAnalysisService() -@tracer_analysis_bp.route("/", methods=["POST"]) +@tracer_analysis_bp.route("", methods=["POST"]) def execute_tracer_analysis(country_id): # Validate country ID @@ -24,24 +24,31 @@ def execute_tracer_analysis(country_id): variable = payload.get("variable") try: - analysis = tracer_analysis_service.execute_analysis( - country_id, - household_id, - policy_id, - variable, + # Create streaming response + response = Response( + stream_with_context( + tracer_analysis_service.execute_analysis( + country_id, + household_id, + policy_id, + variable, + ) + ), + status=200, ) - - return Response(status=200, response=analysis) + + # Set header to prevent buffering on Google App Engine deployment + # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) + response.headers['X-Accel-Buffering'] = 'no' + + return response except Exception as e: - return ( - dict( - status="error", - message="An error occurred while executing the tracer analysis. Details: " + return { + "status": "error", + "message": "An error occurred while executing the tracer analysis. 
Details: " + str(e), - result=None, - ), - 500, - ) + "result": None, + }, 500 def validate_payload(payload: dict): # Validate payload diff --git a/policyengine_api/services/analysis_service.py b/policyengine_api/services/ai_analysis_service.py similarity index 94% rename from policyengine_api/services/analysis_service.py rename to policyengine_api/services/ai_analysis_service.py index 14714da7..8e71e2ee 100644 --- a/policyengine_api/services/analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -12,7 +12,7 @@ class AIAnalysisService: local database table """ - def get_existing_analysis(prompt: str) -> Generator[str, None, None] | None: + def get_existing_analysis(self, prompt: str) -> Generator[str, None, None] | None: """ Get existing analysis from the local database """ @@ -42,7 +42,7 @@ def generate(): return generate() - def trigger_ai_analysis(prompt: str) -> Generator[str, None, None]: + def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: # Configure a Claude client claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index da5352a3..ed5c62bc 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -1,8 +1,8 @@ import json -from policyengine_api.services.analysis_service import AnalysisService +from policyengine_api.services.ai_analysis_service import AIAnalysisService -class SimulationAnalysisService(AnalysisService): +class SimulationAnalysisService(AIAnalysisService): """ Service for generating AI analysis of economy-wide simulation runs; this is connected with the simulation_analysis route and @@ -17,12 +17,12 @@ def execute_analysis( currency: str, selected_version: str, time_period: str, - impact: str, + impact: dict, policy_label: str, - policy: str, + policy: dict, region: str, - relevant_parameters: str, - relevant_parameter_baseline_values: str, + relevant_parameters: list, + relevant_parameter_baseline_values: list, audience: str | None, ): @@ -65,6 +65,7 @@ def execute_analysis( raise e def _generate_simulation_analysis_prompt( + self, time_period, region, currency, diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index f6eb6d22..19e1fd98 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -1,13 +1,12 @@ from policyengine_api.data import local_database import json -from flask import stream_with_context from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS from typing import Generator import re import anthropic -from policyengine_api.services.analysis_service import AnalysisService +from policyengine_api.services.ai_analysis_service import AIAnalysisService -class TracerAnalysisService(AnalysisService): +class TracerAnalysisService(AIAnalysisService): def __init__(self): super().__init__() @@ -48,12 +47,12 @@ def execute_analysis( # streaming response existing_analysis: Generator = self.get_existing_analysis(prompt) if existing_analysis is not None: - return stream_with_context(existing_analysis) + return existing_analysis # Otherwise, pass prompt to Claude, then return streaming function try: analysis: Generator = self.trigger_ai_analysis(prompt) - return stream_with_context(analysis) + return analysis except Exception as e: print(f"Error 
generating AI analysis within tracer analysis service: {str(e)}") raise e From e508b5d26e9bd7a06fbeb59eba755e3266f81598 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 Nov 2024 00:21:44 +0100 Subject: [PATCH 05/14] chore: Lint and changelog --- changelog_entry.yaml | 5 + policyengine_api/api.py | 14 +- .../routes/simulation_analysis_routes.py | 19 +- .../routes/tracer_analysis_routes.py | 26 ++- .../services/ai_analysis_service.py | 13 +- policyengine_api/services/economy_service.py | 1 + policyengine_api/services/job_service.py | 1 + policyengine_api/services/policy_service.py | 1 + .../services/simulation_analysis_service.py | 146 +++++++------ .../services/tracer_analysis_service.py | 205 +++++++++--------- 10 files changed, 235 insertions(+), 196 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..d1bc7e4e 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + changed: + - Refactored AI endpoints to match new routes/services/jobs architecture + - Disabled default buffering on App Engine deployments for AI endpoints \ No newline at end of file diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 162fa00d..3f3bd725 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -14,7 +14,9 @@ # Endpoints from policyengine_api.routes.economy_routes import economy_bp -from policyengine_api.routes.simulation_analysis_routes import simulation_analysis_bp +from policyengine_api.routes.simulation_analysis_routes import ( + simulation_analysis_bp, +) from policyengine_api.routes.tracer_analysis_routes import tracer_analysis_bp from .endpoints import ( get_home, @@ -95,9 +97,8 @@ # Routes for AI analysis of economy microsim runs app.register_blueprint( - simulation_analysis_bp, - url_prefix="//simulation-analysis" - ) + simulation_analysis_bp, url_prefix="//simulation-analysis" +) app.route("//user-policy", methods=["POST"])(set_user_policy) @@ -115,7 +116,10 @@ app.route("/simulations", methods=["GET"])(get_simulations) -app.register_blueprint(tracer_analysis_bp, url_prefix="//tracer-analysis") +app.register_blueprint( + tracer_analysis_bp, url_prefix="//tracer-analysis" +) + @app.route("/liveness-check", methods=["GET"]) def liveness_check(): diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index da7dfe77..73ad48fd 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,10 +1,13 @@ from flask import Blueprint, request, Response, stream_with_context from policyengine_api.helpers import validate_country -from policyengine_api.services.simulation_analysis_service import SimulationAnalysisService +from policyengine_api.services.simulation_analysis_service import ( + SimulationAnalysisService, +) simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() + @simulation_analysis_bp.route("", methods=["POST"]) def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") @@ -20,7 +23,9 @@ def execute_simulation_analysis(country_id): is_payload_valid, message = validate_payload(payload) if not is_payload_valid: - return Response(status=400, response=f"Invalid JSON data; details: {message}") + return Response( + status=400, response=f"Invalid JSON data; details: {message}" + ) currency: str = payload.get("currency") selected_version: str = 
payload.get("selected_version") @@ -58,16 +63,18 @@ def execute_simulation_analysis(country_id): # Set header to prevent buffering on Google App Engine deployment # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) - response.headers['X-Accel-Buffering'] = 'no' + response.headers["X-Accel-Buffering"] = "no" return response except Exception as e: return { "status": "error", - "message": "An error occurred while executing the simulation analysis. Details: " + str(e), + "message": "An error occurred while executing the simulation analysis. Details: " + + str(e), "result": None, }, 500 + def validate_payload(payload: dict): # Check if all required keys are present; note # that the audience key is optional @@ -97,7 +104,7 @@ def validate_payload(payload: dict): missing_keys = [key for key in required_keys if key not in payload] if missing_keys: return False, f"Missing required keys: {missing_keys}" - + # Check if all keys are of the right type for key, value in payload.items(): if key in str_keys and not isinstance(value, str): @@ -107,4 +114,4 @@ def validate_payload(payload: dict): elif key in list_keys and not isinstance(value, list): return False, f"Key '{key}' must be a list" - return True, None \ No newline at end of file + return True, None diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 2fcdae9a..4194ac98 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,10 +1,13 @@ from flask import Blueprint, request, Response, stream_with_context from policyengine_api.helpers import validate_country -from policyengine_api.services.tracer_analysis_service import TracerAnalysisService +from policyengine_api.services.tracer_analysis_service import ( + TracerAnalysisService, +) tracer_analysis_bp = Blueprint("tracer_analysis", __name__) tracer_analysis_service = TracerAnalysisService() + @tracer_analysis_bp.route("", methods=["POST"]) def execute_tracer_analysis(country_id): @@ -17,7 +20,9 @@ def execute_tracer_analysis(country_id): is_payload_valid, message = validate_payload(payload) if not is_payload_valid: - return Response(status=400, response=f"Invalid JSON data; details: {message}") + return Response( + status=400, response=f"Invalid JSON data; details: {message}" + ) household_id = payload.get("household_id") policy_id = payload.get("policy_id") @@ -36,20 +41,21 @@ def execute_tracer_analysis(country_id): ), status=200, ) - + # Set header to prevent buffering on Google App Engine deployment # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) - response.headers['X-Accel-Buffering'] = 'no' - + response.headers["X-Accel-Buffering"] = "no" + return response except Exception as e: return { - "status": "error", - "message": "An error occurred while executing the tracer analysis. Details: " - + str(e), - "result": None, + "status": "error", + "message": "An error occurred while executing the tracer analysis. 
Details: " + + str(e), + "result": None, }, 500 + def validate_payload(payload: dict): # Validate payload if not payload: @@ -60,4 +66,4 @@ def validate_payload(payload: dict): if key not in payload: return False, f"Missing required key: {key}" - return True, None \ No newline at end of file + return True, None diff --git a/policyengine_api/services/ai_analysis_service.py b/policyengine_api/services/ai_analysis_service.py index 8e71e2ee..fd72c932 100644 --- a/policyengine_api/services/ai_analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -5,6 +5,7 @@ from typing import Generator from policyengine_api.data import local_database + class AIAnalysisService: """ Base class for various AI analysis-based services, @@ -12,7 +13,9 @@ class AIAnalysisService: local database table """ - def get_existing_analysis(self, prompt: str) -> Generator[str, None, None] | None: + def get_existing_analysis( + self, prompt: str + ) -> Generator[str, None, None] | None: """ Get existing analysis from the local database """ @@ -43,9 +46,11 @@ def generate(): return generate() def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: - + # Configure a Claude client - claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + claude_client = anthropic.Anthropic( + api_key=os.getenv("ANTHROPIC_API_KEY") + ) def generate(): chunk_size = 5 @@ -84,4 +89,4 @@ def generate(): (prompt, response_text, "ok"), ) - return generate() \ No newline at end of file + return generate() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index da7bd34c..2806506b 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -18,6 +18,7 @@ class EconomyService: to the /economy route, which does not have its own table; therefore, it connects with other services to access their respective tables """ + def get_economic_impact( self, country_id, diff --git a/policyengine_api/services/job_service.py b/policyengine_api/services/job_service.py index f8d61c69..721e44ec 100644 --- a/policyengine_api/services/job_service.py +++ b/policyengine_api/services/job_service.py @@ -24,6 +24,7 @@ class JobService(metaclass=Singleton): jobs. This is not connected to any routes or tables, but interfaces with the Redis queue to enqueue jobs and track their status. 
""" + def __init__(self): self.recent_jobs = {} diff --git a/policyengine_api/services/policy_service.py b/policyengine_api/services/policy_service.py index 4367ff25..0a820f83 100644 --- a/policyengine_api/services/policy_service.py +++ b/policyengine_api/services/policy_service.py @@ -7,6 +7,7 @@ class PolicyService: this will be connected to the /policy route and is partially connected to the policy database table """ + def get_policy_json(self, country_id, policy_id): try: policy_json = database.query( diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index ed5c62bc..99aff720 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -2,37 +2,74 @@ from policyengine_api.services.ai_analysis_service import AIAnalysisService + class SimulationAnalysisService(AIAnalysisService): - """ - Service for generating AI analysis of economy-wide simulation - runs; this is connected with the simulation_analysis route and - analysis database table - """ - def __init__(self): - super().__init__() - - def execute_analysis( - self, - country_id: str, - currency: str, - selected_version: str, - time_period: str, - impact: dict, - policy_label: str, - policy: dict, - region: str, - relevant_parameters: list, - relevant_parameter_baseline_values: list, - audience: str | None, - ): - - # Check if the region is enhanced_cps - is_enhanced_cps = "enhanced_cps" in region - - print("Generating prompt for economy-wide simulation analysis") - - # Create prompt based on data - prompt = self._generate_simulation_analysis_prompt( + """ + Service for generating AI analysis of economy-wide simulation + runs; this is connected with the simulation_analysis route and + analysis database table + """ + + def __init__(self): + super().__init__() + + def execute_analysis( + self, + country_id: str, + currency: str, + selected_version: str, + time_period: str, + impact: dict, + policy_label: str, + policy: dict, + region: str, + relevant_parameters: list, + relevant_parameter_baseline_values: list, + audience: str | None, + ): + + # Check if the region is enhanced_cps + is_enhanced_cps = "enhanced_cps" in region + + print("Generating prompt for economy-wide simulation analysis") + + # Create prompt based on data + prompt = self._generate_simulation_analysis_prompt( + time_period, + region, + currency, + policy, + impact, + relevant_parameters, + relevant_parameter_baseline_values, + is_enhanced_cps, + selected_version, + country_id, + policy_label, + ) + + # Add audience description to end + prompt += self.audience_descriptions[audience] + + print("Checking if AI analysis already exists for this prompt") + # If a calculated record exists for this prompt, return it as a + # streaming response + existing_analysis = self.get_existing_analysis(prompt) + if existing_analysis is not None: + return existing_analysis + + print( + "Found no existing AI analysis; triggering new analysis with Claude" + ) + # Otherwise, pass prompt to Claude, then return streaming function + try: + analysis = self.trigger_ai_analysis(prompt) + return analysis + except Exception as e: + raise e + + def _generate_simulation_analysis_prompt( + self, time_period, region, currency, @@ -44,41 +81,8 @@ def execute_analysis( selected_version, country_id, policy_label, - ) - - # Add audience description to end - prompt += self.audience_descriptions[audience] - - print("Checking if AI analysis already exists for 
this prompt") - # If a calculated record exists for this prompt, return it as a - # streaming response - existing_analysis = self.get_existing_analysis(prompt) - if existing_analysis is not None: - return existing_analysis - - print("Found no existing AI analysis; triggering new analysis with Claude") - # Otherwise, pass prompt to Claude, then return streaming function - try: - analysis = self.trigger_ai_analysis(prompt) - return analysis - except Exception as e: - raise e - - def _generate_simulation_analysis_prompt( - self, - time_period, - region, - currency, - policy, - impact, - relevant_parameters, - relevant_parameter_baseline_values, - is_enhanced_cps, - selected_version, - country_id, - policy_label, - ): - return f""" + ): + return f""" I'm using PolicyEngine, a free, open source tool to compute the impact of public policy. I'm writing up an economic analysis of a hypothetical tax-benefit policy reform. Please write the analysis for me using the details below, in @@ -184,9 +188,9 @@ def _generate_simulation_analysis_prompt( impact["inequality"], )} """ - - audience_descriptions = { - "ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", - "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", - "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.", - } \ No newline at end of file + + audience_descriptions = { + "ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", + "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", + "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. 
Use acronyms and jargon if it makes the content more concise and informative.", + } diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index 19e1fd98..162245d2 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -6,114 +6,119 @@ import anthropic from policyengine_api.services.ai_analysis_service import AIAnalysisService + class TracerAnalysisService(AIAnalysisService): - def __init__(self): - super().__init__() - - def execute_analysis( - self, - country_id: str, - household_id: str, - policy_id: str, - variable: str, - ): - - api_version = COUNTRY_PACKAGE_VERSIONS[country_id] - - # Retrieve tracer record from table - try: - tracer: list[str] = self.get_tracer( - country_id, - household_id, - policy_id, - api_version, - ) - except Exception as e: - raise e - - # Parse the tracer output for our given variable - try: - tracer_segment: list[str] = self._parse_tracer_output(tracer, variable) - except Exception as e: - print(f"Error parsing tracer output: {str(e)}") - raise e - - # Add the parsed tracer output to the prompt - prompt = self.prompt_template.format( - variable=variable, tracer_segment=tracer_segment - ) - - # If a calculated record exists for this prompt, return it as a - # streaming response - existing_analysis: Generator = self.get_existing_analysis(prompt) - if existing_analysis is not None: - return existing_analysis - - # Otherwise, pass prompt to Claude, then return streaming function - try: - analysis: Generator = self.trigger_ai_analysis(prompt) - return analysis - except Exception as e: - print(f"Error generating AI analysis within tracer analysis service: {str(e)}") - raise e - - def get_tracer( + def __init__(self): + super().__init__() + + def execute_analysis( + self, + country_id: str, + household_id: str, + policy_id: str, + variable: str, + ): + + api_version = COUNTRY_PACKAGE_VERSIONS[country_id] + + # Retrieve tracer record from table + try: + tracer: list[str] = self.get_tracer( + country_id, + household_id, + policy_id, + api_version, + ) + except Exception as e: + raise e + + # Parse the tracer output for our given variable + try: + tracer_segment: list[str] = self._parse_tracer_output( + tracer, variable + ) + except Exception as e: + print(f"Error parsing tracer output: {str(e)}") + raise e + + # Add the parsed tracer output to the prompt + prompt = self.prompt_template.format( + variable=variable, tracer_segment=tracer_segment + ) + + # If a calculated record exists for this prompt, return it as a + # streaming response + existing_analysis: Generator = self.get_existing_analysis(prompt) + if existing_analysis is not None: + return existing_analysis + + # Otherwise, pass prompt to Claude, then return streaming function + try: + analysis: Generator = self.trigger_ai_analysis(prompt) + return analysis + except Exception as e: + print( + f"Error generating AI analysis within tracer analysis service: {str(e)}" + ) + raise e + + def get_tracer( self, country_id: str, household_id: str, policy_id: str, api_version: str, - ) -> list: - try: - # Retrieve from the tracers table in the local database - row = local_database.query( - """ + ) -> list: + try: + # Retrieve from the tracers table in the local database + row = local_database.query( + """ SELECT * FROM tracers WHERE household_id = ? AND policy_id = ? AND country_id = ? AND api_version = ? 
""", - (household_id, policy_id, country_id, api_version), - ).fetchone() - - if row is None: - raise KeyError("No tracer found for this household") - - tracer_output_list = json.loads(row["tracer_output"]) - return tracer_output_list - - except Exception as e: - print(f"Error getting existing tracer analysis: {str(e)}") - raise e - - def _parse_tracer_output(self, tracer_output, target_variable): - result = [] - target_indent = None - capturing = False - - # Create a regex pattern to match the exact variable name - # This will match the variable name followed by optional whitespace, - # then optional angle brackets with any content, then optional whitespace - pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*" - - for line in tracer_output: - # Count leading spaces to determine indentation level - indent = len(line) - len(line.strip()) - - # Check if this line matches our target variable - match = re.match(pattern, line) - if match and not capturing: - target_indent = indent - capturing = True - result.append(line) - elif capturing: - # Stop capturing if we encounter a line with less indentation than the target - if indent <= target_indent: - break - # Capture dependencies (lines with greater indentation) - result.append(line) - - return result - - prompt_template = f"""{anthropic.HUMAN_PROMPT} You are an AI assistant explaining US policy calculations. + (household_id, policy_id, country_id, api_version), + ).fetchone() + + if row is None: + raise KeyError("No tracer found for this household") + + tracer_output_list = json.loads(row["tracer_output"]) + return tracer_output_list + + except Exception as e: + print(f"Error getting existing tracer analysis: {str(e)}") + raise e + + def _parse_tracer_output(self, tracer_output, target_variable): + result = [] + target_indent = None + capturing = False + + # Create a regex pattern to match the exact variable name + # This will match the variable name followed by optional whitespace, + # then optional angle brackets with any content, then optional whitespace + pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*" + + for line in tracer_output: + # Count leading spaces to determine indentation level + indent = len(line) - len(line.strip()) + + # Check if this line matches our target variable + match = re.match(pattern, line) + if match and not capturing: + target_indent = indent + capturing = True + result.append(line) + elif capturing: + # Stop capturing if we encounter a line with less indentation than the target + if indent <= target_indent: + break + # Capture dependencies (lines with greater indentation) + result.append(line) + + return result + + prompt_template = f"""{anthropic.HUMAN_PROMPT} You are an AI assistant explaining US policy calculations. The user has run a simulation for the variable '{{variable}}'. Here's the tracer output: {{tracer_segment}} @@ -126,4 +131,4 @@ def _parse_tracer_output(self, tracer_output, target_variable): Keep your explanation concise but informative, suitable for a general audience. Do not start with phrases like "Certainly!" or "Here's an explanation. It will be rendered as markdown, so preface $ with \. 
- {anthropic.AI_PROMPT}""" \ No newline at end of file + {anthropic.AI_PROMPT}""" From e55d424f5344544498a360e7df0a373f996085a7 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 Nov 2024 19:28:24 +0100 Subject: [PATCH 06/14] chore: Update tests and modify code based on test failures --- changelog_entry.yaml | 3 +- policyengine_api/routes/economy_routes.py | 12 +- .../routes/simulation_analysis_routes.py | 22 +- .../routes/tracer_analysis_routes.py | 32 ++- .../services/simulation_analysis_service.py | 6 +- tests/python/test_ai_analysis.py | 27 ++- tests/python/test_simulation_analysis.py | 204 +++++++----------- tests/python/test_tracer.py | 43 ++-- 8 files changed, 169 insertions(+), 180 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index d1bc7e4e..4b991734 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -2,4 +2,5 @@ changes: changed: - Refactored AI endpoints to match new routes/services/jobs architecture - - Disabled default buffering on App Engine deployments for AI endpoints \ No newline at end of file + - Disabled default buffering on App Engine deployments for AI endpoints + - Updated relevant tests \ No newline at end of file diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 7b800d19..37776a2a 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -43,12 +43,12 @@ def get_economic_impact(country_id, policy_id, baseline_policy_id): ) return result except Exception as e: - return ( - dict( - status="error", - message="An error occurred while calculating the economic impact. Details: " + return Response( + { + "status": "error", + "message": "An error occurred while calculating the economic impact. Details: " + str(e), - result=None, - ), + "result": None, + }, 500, ) diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 73ad48fd..eaf20f3d 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,4 +1,5 @@ from flask import Blueprint, request, Response, stream_with_context +import json from policyengine_api.helpers import validate_country from policyengine_api.services.simulation_analysis_service import ( SimulationAnalysisService, @@ -34,8 +35,8 @@ def execute_simulation_analysis(country_id): policy_label: str = payload.get("policy_label") policy: dict = payload.get("policy") region: str = payload.get("region") - relevant_parameters: list = payload.get("relevant_parameters") - relevant_parameter_baseline_values: list = payload.get( + relevant_parameters: list[dict] = payload.get("relevant_parameters") + relevant_parameter_baseline_values: list[dict] = payload.get( "relevant_parameter_baseline_values" ) audience = payload.get("audience", "") @@ -67,12 +68,17 @@ def execute_simulation_analysis(country_id): return response except Exception as e: - return { - "status": "error", - "message": "An error occurred while executing the simulation analysis. Details: " - + str(e), - "result": None, - }, 500 + return Response( + json.dumps( + { + "status": "error", + "message": "An error occurred while executing the simulation analysis. 
Details: " + + str(e), + "result": None, + } + ), + status=500, + ) def validate_payload(payload: dict): diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 4194ac98..92fbafbe 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -3,6 +3,7 @@ from policyengine_api.services.tracer_analysis_service import ( TracerAnalysisService, ) +import json tracer_analysis_bp = Blueprint("tracer_analysis", __name__) tracer_analysis_service = TracerAnalysisService() @@ -47,13 +48,32 @@ def execute_tracer_analysis(country_id): response.headers["X-Accel-Buffering"] = "no" return response + except KeyError as e: + """ + This exception is raised when the tracer can't find a household tracer record + """ + return Response( + json.dumps( + { + "status": "not found", + "message": "No household simulation tracer found", + "result": None, + } + ), + 404, + ) except Exception as e: - return { - "status": "error", - "message": "An error occurred while executing the tracer analysis. Details: " - + str(e), - "result": None, - }, 500 + return Response( + json.dumps( + { + "status": "error", + "message": "An error occurred while executing the tracer analysis. Details: " + + str(e), + "result": None, + } + ), + 500, + ) def validate_payload(payload: dict): diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index 99aff720..2e5c31b2 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -23,13 +23,13 @@ def execute_analysis( policy_label: str, policy: dict, region: str, - relevant_parameters: list, - relevant_parameter_baseline_values: list, + relevant_parameters: list[dict], + relevant_parameter_baseline_values: list[dict], audience: str | None, ): # Check if the region is enhanced_cps - is_enhanced_cps = "enhanced_cps" in region + is_enhanced_cps = "enhanced_us" in region print("Generating prompt for economy-wide simulation analysis") diff --git a/tests/python/test_ai_analysis.py b/tests/python/test_ai_analysis.py index 9f7bec68..5df4adbf 100644 --- a/tests/python/test_ai_analysis.py +++ b/tests/python/test_ai_analysis.py @@ -2,14 +2,13 @@ from unittest.mock import patch, MagicMock import json import os -from policyengine_api.utils.ai_analysis import ( - trigger_ai_analysis, - get_existing_analysis, -) +from policyengine_api.services.ai_analysis_service import AIAnalysisService +test_ai_service = AIAnalysisService() -@patch("policyengine_api.utils.ai_analysis.anthropic.Anthropic") -@patch("policyengine_api.utils.ai_analysis.local_database") + +@patch("policyengine_api.services.ai_analysis_service.anthropic.Anthropic") +@patch("policyengine_api.services.ai_analysis_service.local_database") def test_trigger_ai_analysis(mock_db, mock_anthropic): mock_client = MagicMock() mock_anthropic.return_value = mock_client @@ -20,7 +19,7 @@ def test_trigger_ai_analysis(mock_db, mock_anthropic): ) prompt = "Test prompt" - generator = trigger_ai_analysis(prompt) + generator = test_ai_service.trigger_ai_analysis(prompt) # Check initial yield initial_data = json.loads(next(generator)) @@ -47,15 +46,15 @@ def test_trigger_ai_analysis(mock_db, mock_anthropic): ) -@patch("policyengine_api.utils.ai_analysis.local_database") -@patch("policyengine_api.utils.ai_analysis.time.sleep") +@patch("policyengine_api.services.ai_analysis_service.local_database") 
+@patch("policyengine_api.services.ai_analysis_service.time.sleep") def test_get_existing_analysis_found(mock_sleep, mock_db): mock_db.query.return_value.fetchone.return_value = { "analysis": "Existing analysis" } prompt = "Test prompt" - generator = get_existing_analysis(prompt) + generator = test_ai_service.get_existing_analysis(prompt) # Check initial yield initial_data = json.loads(next(generator)) @@ -76,12 +75,12 @@ def test_get_existing_analysis_found(mock_sleep, mock_db): assert mock_sleep.call_count == 4 -@patch("policyengine_api.utils.ai_analysis.local_database") +@patch("policyengine_api.services.ai_analysis_service.local_database") def test_get_existing_analysis_not_found(mock_db): mock_db.query.return_value.fetchone.return_value = None prompt = "Test prompt" - result = get_existing_analysis(prompt) + result = test_ai_service.get_existing_analysis(prompt) assert result is None mock_db.query.assert_called_once_with( @@ -97,14 +96,14 @@ def test_anthropic_api_key(): # Test error handling in trigger_ai_analysis -@patch("policyengine_api.utils.ai_analysis.anthropic.Anthropic") +@patch("policyengine_api.services.ai_analysis_service.anthropic.Anthropic") def test_trigger_ai_analysis_error(mock_anthropic): mock_client = MagicMock() mock_anthropic.return_value = mock_client mock_client.messages.stream.side_effect = Exception("API Error") prompt = "Test prompt" - generator = trigger_ai_analysis(prompt) + generator = test_ai_service.trigger_ai_analysis(prompt) # Check initial yield initial_data = json.loads(next(generator)) diff --git a/tests/python/test_simulation_analysis.py b/tests/python/test_simulation_analysis.py index f63dd86d..2b0be2f0 100644 --- a/tests/python/test_simulation_analysis.py +++ b/tests/python/test_simulation_analysis.py @@ -1,8 +1,15 @@ import pytest -from unittest.mock import patch, MagicMock -from flask import Flask, jsonify -from policyengine_api.data import local_database -from policyengine_api.endpoints import execute_simulation_analysis +from unittest.mock import patch +from flask import Flask + +from policyengine_api.services.simulation_analysis_service import ( + SimulationAnalysisService, +) +from policyengine_api.routes.simulation_analysis_routes import ( + execute_simulation_analysis, +) + +test_service = SimulationAnalysisService() @pytest.fixture @@ -12,126 +19,75 @@ def app(): return app +test_impact = { + "budget": 1000, + "intra_decile": 0.1, + "decile": 0.2, + "poverty": { + "poverty": 0.3, + "deep_poverty": 0.4, + }, + "poverty_by_gender": 0.5, + "poverty_by_race": {"poverty": 0.6}, + "inequality": 0.7, +} + +test_json = { + "currency": "USD", + "selected_version": "2023", + "time_period": "2023", + "impact": test_impact, + "policy_label": "Test Policy", + "policy": dict(policy_json="policy details"), + "region": "US", + "relevant_parameters": ["param1", "param2"], + "relevant_parameter_baseline_values": [ + {"param1": 100}, + {"param2": 200}, + ], + "audience": "Normal", +} + + def test_execute_simulation_analysis_existing_analysis(app, rest_client): - test_impact = { - "budget": 1000, - "intra_decile": 0.1, - "decile": 0.2, - "poverty": { - "poverty": 0.3, - "deep_poverty": 0.4, - }, - "poverty_by_gender": 0.5, - "poverty_by_race": {"poverty": 0.6}, - "inequality": 0.7, - } - with app.test_request_context( - json={ - "currency": "USD", - "selected_version": "2023", - "time_period": "2023", - "impact": test_impact, - "policy_label": "Test Policy", - "policy": "policy details", - "region": "US", - "relevant_parameters": ["param1", "param2"], - 
"relevant_parameter_baseline_values": { - "param1": 100, - "param2": 200, - }, - "audience": "Normal", - } - ): + with app.test_request_context(json=test_json): with patch( - "policyengine_api.endpoints.simulation_analysis.get_existing_analysis" + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" ) as mock_get_existing: mock_get_existing.return_value = (s for s in ["Existing analysis"]) - response = execute_simulation_analysis("US") + response = execute_simulation_analysis("us") assert response.status_code == 200 assert b"Existing analysis" in response.data def test_execute_simulation_analysis_new_analysis(app, rest_client): - test_impact = { - "budget": 1000, - "intra_decile": 0.1, - "decile": 0.2, - "poverty": { - "poverty": 0.3, - "deep_poverty": 0.4, - }, - "poverty_by_gender": 0.5, - "poverty_by_race": {"poverty": 0.6}, - "inequality": 0.7, - } - with app.test_request_context( - json={ - "currency": "USD", - "selected_version": "2023", - "time_period": "2023", - "impact": test_impact, - "policy_label": "Test Policy", - "policy": "policy details", - "region": "US", - "relevant_parameters": ["param1", "param2"], - "relevant_parameter_baseline_values": { - "param1": 100, - "param2": 200, - }, - "audience": "Normal", - } - ): + with app.test_request_context(json=test_json): with patch( - "policyengine_api.endpoints.simulation_analysis.get_existing_analysis" + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" ) as mock_get_existing: mock_get_existing.return_value = None with patch( - "policyengine_api.endpoints.simulation_analysis.trigger_ai_analysis" + "policyengine_api.services.simulation_analysis_service.AIAnalysisService.trigger_ai_analysis" ) as mock_trigger: mock_trigger.return_value = (s for s in ["New analysis"]) - response = execute_simulation_analysis("US") + response = execute_simulation_analysis("us") assert response.status_code == 200 assert b"New analysis" in response.data def test_execute_simulation_analysis_error(app, rest_client): - test_impact = { - "budget": 1000, - "intra_decile": 0.1, - "decile": 0.2, - "poverty": {"poverty": 0.3, "deep_poverty": 0.4}, - "poverty_by_gender": 0.5, - "poverty_by_race": {"poverty": 0.6}, - "inequality": 0.7, - } - with app.test_request_context( - json={ - "currency": "USD", - "selected_version": "2023", - "time_period": "2023", - "impact": test_impact, - "policy_label": "Test Policy", - "policy": "policy details", - "region": "US", - "relevant_parameters": ["param1", "param2"], - "relevant_parameter_baseline_values": { - "param1": 100, - "param2": 200, - }, - "audience": "Normal", - } - ): + with app.test_request_context(json=test_json): with patch( - "policyengine_api.endpoints.simulation_analysis.get_existing_analysis" + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" ) as mock_get_existing: mock_get_existing.return_value = None with patch( - "policyengine_api.endpoints.simulation_analysis.trigger_ai_analysis" + "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" ) as mock_trigger: mock_trigger.side_effect = Exception("Test error") @@ -141,60 +97,52 @@ def test_execute_simulation_analysis_error(app, rest_client): def test_execute_simulation_analysis_enhanced_cps(app, rest_client): - test_impact = { - "budget": 1000, - "intra_decile": 0.1, - "decile": 0.2, - "poverty": {"poverty": 0.3, "deep_poverty": 0.4}, - "poverty_by_gender": 0.5, - "poverty_by_race": {"poverty": 0.6}, - "inequality": 
0.7, + policy_details = dict(policy_json="policy details") + + test_json_enhanced_us = { + "currency": "USD", + "selected_version": "2023", + "time_period": "2023", + "impact": test_impact, + "policy_label": "Test Policy", + "policy": policy_details, + "region": "enhanced_us", + "relevant_parameters": ["param1", "param2"], + "relevant_parameter_baseline_values": [ + {"param1": 100}, + {"param2": 200}, + ], + "audience": "Normal", } - with app.test_request_context( - json={ - "currency": "USD", - "selected_version": "2023", - "time_period": "2023", - "impact": test_impact, - "policy_label": "Test Policy", - "policy": "policy details", - "region": "enhanced_cps_US", - "relevant_parameters": ["param1", "param2"], - "relevant_parameter_baseline_values": { - "param1": 100, - "param2": 200, - }, - "audience": "Normal", - } - ): + with app.test_request_context(json=test_json_enhanced_us): with patch( - "policyengine_api.endpoints.simulation_analysis.generate_simulation_analysis_prompt" + "policyengine_api.services.simulation_analysis_service.SimulationAnalysisService._generate_simulation_analysis_prompt" ) as mock_generate_prompt: with patch( - "policyengine_api.endpoints.simulation_analysis.get_existing_analysis" + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" ) as mock_get_existing: mock_get_existing.return_value = None with patch( - "policyengine_api.endpoints.simulation_analysis.trigger_ai_analysis" + "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" ) as mock_trigger: mock_trigger.return_value = ( s for s in ["Enhanced CPS analysis"] ) - response = execute_simulation_analysis("US") + response = execute_simulation_analysis("us") assert response.status_code == 200 assert b"Enhanced CPS analysis" in response.data mock_generate_prompt.assert_called_once_with( "2023", - "enhanced_cps_US", + "enhanced_us", "USD", - "policy details", + policy_details, test_impact, ["param1", "param2"], - {"param1": 100, "param2": 200}, + [{"param1": 100}, {"param2": 200}], True, "2023", - "US", + "us", "Test Policy", ) diff --git a/tests/python/test_tracer.py b/tests/python/test_tracer.py index bd842ee0..2e08bbaf 100644 --- a/tests/python/test_tracer.py +++ b/tests/python/test_tracer.py @@ -1,9 +1,15 @@ import pytest from flask import Flask, json -from unittest.mock import patch, MagicMock -from policyengine_api.endpoints import execute_tracer_analysis -from policyengine_api.utils.tracer_analysis import parse_tracer_output -from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS +from unittest.mock import patch + +from policyengine_api.routes.tracer_analysis_routes import ( + execute_tracer_analysis, +) +from policyengine_api.services.tracer_analysis_service import ( + TracerAnalysisService, +) + +test_service = TracerAnalysisService() @pytest.fixture @@ -25,19 +31,25 @@ def test_parse_tracer_output(): " pension_income <500>", ] - result = parse_tracer_output(tracer_output, "only_government_benefit") + result = test_service._parse_tracer_output( + tracer_output, "only_government_benefit" + ) assert result == tracer_output - result = parse_tracer_output(tracer_output, "market_income") + result = test_service._parse_tracer_output(tracer_output, "market_income") assert result == tracer_output[1:4] - result = parse_tracer_output(tracer_output, "non_market_income") + result = test_service._parse_tracer_output( + tracer_output, "non_market_income" + ) assert result == tracer_output[4:] # Test cases for execute_tracer_analysis function 
-@patch("policyengine_api.endpoints.tracer_analysis.local_database") -@patch("policyengine_api.endpoints.tracer_analysis.trigger_ai_analysis") +@patch("policyengine_api.services.tracer_analysis_service.local_database") +@patch( + "policyengine_api.services.tracer_analysis_service.TracerAnalysisService.trigger_ai_analysis" +) def test_execute_tracer_analysis_success( mock_trigger_ai_analysis, mock_db, app, rest_client ): @@ -66,7 +78,7 @@ def test_execute_tracer_analysis_success( assert b"AI analysis result" in response.data -@patch("policyengine_api.endpoints.tracer_analysis.local_database") +@patch("policyengine_api.services.tracer_analysis_service.local_database") def test_execute_tracer_analysis_no_tracer(mock_db, app, rest_client): mock_db.query.return_value.fetchone.return_value = None @@ -82,12 +94,15 @@ def test_execute_tracer_analysis_no_tracer(mock_db, app, rest_client): assert response.status_code == 404 assert ( - "no household simulation tracer found" in response.response["message"] + "No household simulation tracer found" + in json.loads(response.data)["message"] ) -@patch("policyengine_api.endpoints.tracer_analysis.local_database") -@patch("policyengine_api.endpoints.tracer_analysis.trigger_ai_analysis") +@patch("policyengine_api.services.tracer_analysis_service.local_database") +@patch( + "policyengine_api.services.tracer_analysis_service.TracerAnalysisService.trigger_ai_analysis" +) def test_execute_tracer_analysis_ai_error( mock_trigger_ai_analysis, mock_db, app, rest_client ): @@ -114,7 +129,7 @@ def test_execute_tracer_analysis_ai_error( response = execute_tracer_analysis("us") assert response.status_code == 500 - assert "Error computing analysis" in response.response["message"] + assert "An error occurred" in json.loads(response.data)["message"] # Test invalid country From ae581ec4a603c604f9e594c733302cdeb960a00d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 17:44:02 +0100 Subject: [PATCH 07/14] fix: Separate out payload validator --- .../routes/simulation_analysis_routes.py | 43 +---------- policyengine_api/utils/payload_validators.py | 40 ++++++++++ tests/python/test_payload_validators.py | 77 +++++++++++++++++++ 3 files changed, 118 insertions(+), 42 deletions(-) create mode 100644 policyengine_api/utils/payload_validators.py create mode 100644 tests/python/test_payload_validators.py diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index eaf20f3d..f80b500e 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -4,6 +4,7 @@ from policyengine_api.services.simulation_analysis_service import ( SimulationAnalysisService, ) +from policyengine_api.utils.payload_validators import validate_sim_analysis_payload as validate_payload simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() @@ -79,45 +80,3 @@ def execute_simulation_analysis(country_id): ), status=500, ) - - -def validate_payload(payload: dict): - # Check if all required keys are present; note - # that the audience key is optional - required_keys = [ - "currency", - "selected_version", - "time_period", - "impact", - "policy_label", - "policy", - "region", - "relevant_parameters", - "relevant_parameter_baseline_values", - ] - str_keys = [ - "currency", - "selected_version", - "time_period", - "policy_label", - "region", - ] - dict_keys = [ - "policy", - "impact", - ] - list_keys = 
["relevant_parameters", "relevant_parameter_baseline_values"] - missing_keys = [key for key in required_keys if key not in payload] - if missing_keys: - return False, f"Missing required keys: {missing_keys}" - - # Check if all keys are of the right type - for key, value in payload.items(): - if key in str_keys and not isinstance(value, str): - return False, f"Key '{key}' must be a string" - elif key in dict_keys and not isinstance(value, dict): - return False, f"Key '{key}' must be a dictionary" - elif key in list_keys and not isinstance(value, list): - return False, f"Key '{key}' must be a list" - - return True, None diff --git a/policyengine_api/utils/payload_validators.py b/policyengine_api/utils/payload_validators.py new file mode 100644 index 00000000..cac22bba --- /dev/null +++ b/policyengine_api/utils/payload_validators.py @@ -0,0 +1,40 @@ +def validate_sim_analysis_payload(payload: dict): + # Check if all required keys are present; note + # that the audience key is optional + required_keys = [ + "currency", + "selected_version", + "time_period", + "impact", + "policy_label", + "policy", + "region", + "relevant_parameters", + "relevant_parameter_baseline_values", + ] + str_keys = [ + "currency", + "selected_version", + "time_period", + "policy_label", + "region", + ] + dict_keys = [ + "policy", + "impact", + ] + list_keys = ["relevant_parameters", "relevant_parameter_baseline_values"] + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + return False, f"Missing required keys: {missing_keys}" + + # Check if all keys are of the right type + for key, value in payload.items(): + if key in str_keys and not isinstance(value, str): + return False, f"Key '{key}' must be a string" + elif key in dict_keys and not isinstance(value, dict): + return False, f"Key '{key}' must be a dictionary" + elif key in list_keys and not isinstance(value, list): + return False, f"Key '{key}' must be a list" + + return True, None \ No newline at end of file diff --git a/tests/python/test_payload_validators.py b/tests/python/test_payload_validators.py new file mode 100644 index 00000000..63d82cd9 --- /dev/null +++ b/tests/python/test_payload_validators.py @@ -0,0 +1,77 @@ +import pytest +from typing import Dict, Any, Tuple + +from policyengine_api.utils.payload_validators import validate_sim_analysis_payload + +@pytest.fixture +def valid_payload() -> Dict[str, Any]: + return { + "currency": "USD", + "selected_version": "v1.0", + "time_period": "2024", + "impact": {"value": 100}, + "policy_label": "Test Policy", + "policy": {"type": "tax", "rate": 0.1}, + "region": "NA", + "relevant_parameters": ["param1", "param2"], + "relevant_parameter_baseline_values": [1.0, 2.0], + } + +def test_valid_payload(valid_payload): + """Test that a valid payload passes validation""" + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is True + assert error is None + +def test_missing_required_key(valid_payload): + """Test that missing required keys are detected""" + del valid_payload["currency"] + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is False + assert "Missing required keys: ['currency']" in error + +def test_invalid_string_type(valid_payload): + """Test that wrong type for string fields is detected""" + valid_payload["currency"] = 123 # Should be string + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is False + assert "Key 'currency' must be a string" in error + +def 
test_invalid_dict_type(valid_payload): + """Test that wrong type for dictionary fields is detected""" + valid_payload["impact"] = ["not", "a", "dict"] # Should be dict + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is False + assert "Key 'impact' must be a dictionary" in error + +def test_invalid_list_type(valid_payload): + """Test that wrong type for list fields is detected""" + valid_payload["relevant_parameters"] = "not a list" # Should be list + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is False + assert "Key 'relevant_parameters' must be a list" in error + +def test_extra_keys_allowed(valid_payload): + """Test that extra keys don't cause validation to fail""" + valid_payload["extra_key"] = "some value" + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is True + assert error is None + +@pytest.mark.parametrize("key", [ + "currency", + "selected_version", + "time_period", + "impact", + "policy_label", + "policy", + "region", + "relevant_parameters", + "relevant_parameter_baseline_values" +]) +def test_individual_required_keys(valid_payload, key): + """Test that each required key is properly checked""" + del valid_payload[key] + is_valid, error = validate_sim_analysis_payload(valid_payload) + assert is_valid is False + assert f"Missing required keys: ['{key}']" in error \ No newline at end of file From 196c1f45adb78bf555036b89acaccf6ce402a289 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 17:47:12 +0100 Subject: [PATCH 08/14] fix: Migrate validate_country to payload_validators --- policyengine_api/endpoints/household.py | 2 +- policyengine_api/endpoints/metadata.py | 2 +- policyengine_api/endpoints/policy.py | 2 +- policyengine_api/endpoints/user_profile.py | 2 +- policyengine_api/helpers/__init__.py | 1 - policyengine_api/routes/simulation_analysis_routes.py | 6 +++--- policyengine_api/routes/tracer_analysis_routes.py | 2 +- policyengine_api/utils/payload_validators/__init__.py | 2 ++ .../payload_validators}/validate_country.py | 0 .../validate_sim_analysis_payload.py} | 0 10 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 policyengine_api/utils/payload_validators/__init__.py rename policyengine_api/{helpers => utils/payload_validators}/validate_country.py (100%) rename policyengine_api/utils/{payload_validators.py => payload_validators/validate_sim_analysis_payload.py} (100%) diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index 308f17cd..d5d551cf 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -11,7 +11,7 @@ import json import logging from datetime import date -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country def add_yearly_variables(household, country_id): diff --git a/policyengine_api/endpoints/metadata.py b/policyengine_api/endpoints/metadata.py index af47336e..007a484b 100644 --- a/policyengine_api/endpoints/metadata.py +++ b/policyengine_api/endpoints/metadata.py @@ -1,4 +1,4 @@ -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country from policyengine_api.country import COUNTRIES diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index 14f9d144..18fe62f2 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -1,4 
+1,4 @@ -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country from policyengine_api.data import database from policyengine_api.utils import hash_object from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS diff --git a/policyengine_api/endpoints/user_profile.py b/policyengine_api/endpoints/user_profile.py index a664fb99..e600bafa 100644 --- a/policyengine_api/endpoints/user_profile.py +++ b/policyengine_api/endpoints/user_profile.py @@ -1,5 +1,5 @@ from flask import Response, request -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country from policyengine_api.data import database import json diff --git a/policyengine_api/helpers/__init__.py b/policyengine_api/helpers/__init__.py index eed55b2a..e09cc904 100644 --- a/policyengine_api/helpers/__init__.py +++ b/policyengine_api/helpers/__init__.py @@ -1,2 +1 @@ -from .validate_country import validate_country from .get_current_law import get_current_law_policy_id diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index f80b500e..27767503 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,10 +1,10 @@ from flask import Blueprint, request, Response, stream_with_context import json -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country from policyengine_api.services.simulation_analysis_service import ( SimulationAnalysisService, ) -from policyengine_api.utils.payload_validators import validate_sim_analysis_payload as validate_payload +from policyengine_api.utils.payload_validators import validate_sim_analysis_payload, validate_country simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() @@ -23,7 +23,7 @@ def execute_simulation_analysis(country_id): # where necessary payload = request.json - is_payload_valid, message = validate_payload(payload) + is_payload_valid, message = validate_sim_analysis_payload(payload) if not is_payload_valid: return Response( status=400, response=f"Invalid JSON data; details: {message}" diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 92fbafbe..594028c1 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,5 +1,5 @@ from flask import Blueprint, request, Response, stream_with_context -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country from policyengine_api.services.tracer_analysis_service import ( TracerAnalysisService, ) diff --git a/policyengine_api/utils/payload_validators/__init__.py b/policyengine_api/utils/payload_validators/__init__.py new file mode 100644 index 00000000..d9c16308 --- /dev/null +++ b/policyengine_api/utils/payload_validators/__init__.py @@ -0,0 +1,2 @@ +from .validate_sim_analysis_payload import validate_sim_analysis_payload +from .validate_country import validate_country \ No newline at end of file diff --git a/policyengine_api/helpers/validate_country.py b/policyengine_api/utils/payload_validators/validate_country.py similarity index 100% rename from policyengine_api/helpers/validate_country.py rename to 
policyengine_api/utils/payload_validators/validate_country.py diff --git a/policyengine_api/utils/payload_validators.py b/policyengine_api/utils/payload_validators/validate_sim_analysis_payload.py similarity index 100% rename from policyengine_api/utils/payload_validators.py rename to policyengine_api/utils/payload_validators/validate_sim_analysis_payload.py From 9b513af25e27191c7d376d26529173333dd9cbc0 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 18:00:22 +0100 Subject: [PATCH 09/14] chore: Move test fixtures, drop helpers folder in favor of utils --- policyengine_api/helpers/__init__.py | 1 - policyengine_api/routes/economy_routes.py | 6 ++-- policyengine_api/utils/__init__.py | 1 + .../{helpers => utils}/get_current_law.py | 0 .../fixtures/simulation_analysis_fixtures.py | 28 ++++++++++++++++ tests/python/test_simulation_analysis.py | 32 ++----------------- 6 files changed, 33 insertions(+), 35 deletions(-) delete mode 100644 policyengine_api/helpers/__init__.py rename policyengine_api/{helpers => utils}/get_current_law.py (100%) create mode 100644 tests/fixtures/simulation_analysis_fixtures.py diff --git a/policyengine_api/helpers/__init__.py b/policyengine_api/helpers/__init__.py deleted file mode 100644 index e09cc904..00000000 --- a/policyengine_api/helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .get_current_law import get_current_law_policy_id diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 37776a2a..425c4caf 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -1,9 +1,7 @@ from flask import Blueprint from policyengine_api.services.economy_service import EconomyService -from policyengine_api.helpers import ( - validate_country, - get_current_law_policy_id, -) +from policyengine_api.utils import get_current_law_policy_id +from policyengine_api.utils.payload_validators import validate_country from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS from flask import request, Response import json diff --git a/policyengine_api/utils/__init__.py b/policyengine_api/utils/__init__.py index b86ff4cc..d225c663 100644 --- a/policyengine_api/utils/__init__.py +++ b/policyengine_api/utils/__init__.py @@ -1,3 +1,4 @@ from .json import * from .cache_utils import * from .singleton import Singleton +from .get_current_law import get_current_law_policy_id \ No newline at end of file diff --git a/policyengine_api/helpers/get_current_law.py b/policyengine_api/utils/get_current_law.py similarity index 100% rename from policyengine_api/helpers/get_current_law.py rename to policyengine_api/utils/get_current_law.py diff --git a/tests/fixtures/simulation_analysis_fixtures.py b/tests/fixtures/simulation_analysis_fixtures.py new file mode 100644 index 00000000..c2dbb26e --- /dev/null +++ b/tests/fixtures/simulation_analysis_fixtures.py @@ -0,0 +1,28 @@ +test_impact = { + "budget": 1000, + "intra_decile": 0.1, + "decile": 0.2, + "poverty": { + "poverty": 0.3, + "deep_poverty": 0.4, + }, + "poverty_by_gender": 0.5, + "poverty_by_race": {"poverty": 0.6}, + "inequality": 0.7, +} + +test_json = { + "currency": "USD", + "selected_version": "2023", + "time_period": "2023", + "impact": test_impact, + "policy_label": "Test Policy", + "policy": dict(policy_json="policy details"), + "region": "US", + "relevant_parameters": ["param1", "param2"], + "relevant_parameter_baseline_values": [ + {"param1": 100}, + {"param2": 200}, + ], + "audience": "Normal", +} \ No newline at end 
of file diff --git a/tests/python/test_simulation_analysis.py b/tests/python/test_simulation_analysis.py index 2b0be2f0..9a3292c6 100644 --- a/tests/python/test_simulation_analysis.py +++ b/tests/python/test_simulation_analysis.py @@ -9,6 +9,8 @@ execute_simulation_analysis, ) +from tests.fixtures.simulation_analysis_fixtures import test_json, test_impact + test_service = SimulationAnalysisService() @@ -19,36 +21,6 @@ def app(): return app -test_impact = { - "budget": 1000, - "intra_decile": 0.1, - "decile": 0.2, - "poverty": { - "poverty": 0.3, - "deep_poverty": 0.4, - }, - "poverty_by_gender": 0.5, - "poverty_by_race": {"poverty": 0.6}, - "inequality": 0.7, -} - -test_json = { - "currency": "USD", - "selected_version": "2023", - "time_period": "2023", - "impact": test_impact, - "policy_label": "Test Policy", - "policy": dict(policy_json="policy details"), - "region": "US", - "relevant_parameters": ["param1", "param2"], - "relevant_parameter_baseline_values": [ - {"param1": 100}, - {"param2": 200}, - ], - "audience": "Normal", -} - - def test_execute_simulation_analysis_existing_analysis(app, rest_client): with app.test_request_context(json=test_json): From bf944d51e16234c845ff73c5df0309da8f1ff05a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 18:06:03 +0100 Subject: [PATCH 10/14] fix: Update type hinting --- policyengine_api/services/ai_analysis_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/policyengine_api/services/ai_analysis_service.py b/policyengine_api/services/ai_analysis_service.py index fd72c932..9e6556a0 100644 --- a/policyengine_api/services/ai_analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -2,7 +2,7 @@ import os import json import time -from typing import Generator +from typing import Generator, Optional from policyengine_api.data import local_database @@ -15,7 +15,7 @@ class AIAnalysisService: def get_existing_analysis( self, prompt: str - ) -> Generator[str, None, None] | None: + ) -> Optional[Generator[str, None, None]]: """ Get existing analysis from the local database """ From 107416a9d240a54532d33302af40d21f166b53cb Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 18:06:46 +0100 Subject: [PATCH 11/14] chore: Lint --- changelog_entry.yaml | 4 +- .../routes/simulation_analysis_routes.py | 5 ++- policyengine_api/utils/__init__.py | 2 +- .../utils/payload_validators/__init__.py | 2 +- .../validate_sim_analysis_payload.py | 2 +- .../fixtures/simulation_analysis_fixtures.py | 2 +- tests/python/test_payload_validators.py | 39 ++++++++++++------- 7 files changed, 37 insertions(+), 19 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index 4b991734..e10fbb79 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -3,4 +3,6 @@ changed: - Refactored AI endpoints to match new routes/services/jobs architecture - Disabled default buffering on App Engine deployments for AI endpoints - - Updated relevant tests \ No newline at end of file + - Updated relevant tests + added: + - Testing for simulation payload validators \ No newline at end of file diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 27767503..fcb280b2 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -4,7 +4,10 @@ from policyengine_api.services.simulation_analysis_service import ( SimulationAnalysisService, ) -from 
policyengine_api.utils.payload_validators import validate_sim_analysis_payload, validate_country +from policyengine_api.utils.payload_validators import ( + validate_sim_analysis_payload, + validate_country, +) simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() diff --git a/policyengine_api/utils/__init__.py b/policyengine_api/utils/__init__.py index d225c663..07a267c6 100644 --- a/policyengine_api/utils/__init__.py +++ b/policyengine_api/utils/__init__.py @@ -1,4 +1,4 @@ from .json import * from .cache_utils import * from .singleton import Singleton -from .get_current_law import get_current_law_policy_id \ No newline at end of file +from .get_current_law import get_current_law_policy_id diff --git a/policyengine_api/utils/payload_validators/__init__.py b/policyengine_api/utils/payload_validators/__init__.py index d9c16308..da141ea8 100644 --- a/policyengine_api/utils/payload_validators/__init__.py +++ b/policyengine_api/utils/payload_validators/__init__.py @@ -1,2 +1,2 @@ from .validate_sim_analysis_payload import validate_sim_analysis_payload -from .validate_country import validate_country \ No newline at end of file +from .validate_country import validate_country diff --git a/policyengine_api/utils/payload_validators/validate_sim_analysis_payload.py b/policyengine_api/utils/payload_validators/validate_sim_analysis_payload.py index cac22bba..f3347c36 100644 --- a/policyengine_api/utils/payload_validators/validate_sim_analysis_payload.py +++ b/policyengine_api/utils/payload_validators/validate_sim_analysis_payload.py @@ -37,4 +37,4 @@ def validate_sim_analysis_payload(payload: dict): elif key in list_keys and not isinstance(value, list): return False, f"Key '{key}' must be a list" - return True, None \ No newline at end of file + return True, None diff --git a/tests/fixtures/simulation_analysis_fixtures.py b/tests/fixtures/simulation_analysis_fixtures.py index c2dbb26e..63e191d0 100644 --- a/tests/fixtures/simulation_analysis_fixtures.py +++ b/tests/fixtures/simulation_analysis_fixtures.py @@ -25,4 +25,4 @@ {"param2": 200}, ], "audience": "Normal", -} \ No newline at end of file +} diff --git a/tests/python/test_payload_validators.py b/tests/python/test_payload_validators.py index 63d82cd9..bdab5027 100644 --- a/tests/python/test_payload_validators.py +++ b/tests/python/test_payload_validators.py @@ -1,7 +1,10 @@ import pytest from typing import Dict, Any, Tuple -from policyengine_api.utils.payload_validators import validate_sim_analysis_payload +from policyengine_api.utils.payload_validators import ( + validate_sim_analysis_payload, +) + @pytest.fixture def valid_payload() -> Dict[str, Any]: @@ -17,12 +20,14 @@ def valid_payload() -> Dict[str, Any]: "relevant_parameter_baseline_values": [1.0, 2.0], } + def test_valid_payload(valid_payload): """Test that a valid payload passes validation""" is_valid, error = validate_sim_analysis_payload(valid_payload) assert is_valid is True assert error is None + def test_missing_required_key(valid_payload): """Test that missing required keys are detected""" del valid_payload["currency"] @@ -30,6 +35,7 @@ def test_missing_required_key(valid_payload): assert is_valid is False assert "Missing required keys: ['currency']" in error + def test_invalid_string_type(valid_payload): """Test that wrong type for string fields is detected""" valid_payload["currency"] = 123 # Should be string @@ -37,6 +43,7 @@ def test_invalid_string_type(valid_payload): assert is_valid is False assert "Key 
'currency' must be a string" in error + def test_invalid_dict_type(valid_payload): """Test that wrong type for dictionary fields is detected""" valid_payload["impact"] = ["not", "a", "dict"] # Should be dict @@ -44,6 +51,7 @@ def test_invalid_dict_type(valid_payload): assert is_valid is False assert "Key 'impact' must be a dictionary" in error + def test_invalid_list_type(valid_payload): """Test that wrong type for list fields is detected""" valid_payload["relevant_parameters"] = "not a list" # Should be list @@ -51,6 +59,7 @@ def test_invalid_list_type(valid_payload): assert is_valid is False assert "Key 'relevant_parameters' must be a list" in error + def test_extra_keys_allowed(valid_payload): """Test that extra keys don't cause validation to fail""" valid_payload["extra_key"] = "some value" @@ -58,20 +67,24 @@ def test_extra_keys_allowed(valid_payload): assert is_valid is True assert error is None -@pytest.mark.parametrize("key", [ - "currency", - "selected_version", - "time_period", - "impact", - "policy_label", - "policy", - "region", - "relevant_parameters", - "relevant_parameter_baseline_values" -]) + +@pytest.mark.parametrize( + "key", + [ + "currency", + "selected_version", + "time_period", + "impact", + "policy_label", + "policy", + "region", + "relevant_parameters", + "relevant_parameter_baseline_values", + ], +) def test_individual_required_keys(valid_payload, key): """Test that each required key is properly checked""" del valid_payload[key] is_valid, error = validate_sim_analysis_payload(valid_payload) assert is_valid is False - assert f"Missing required keys: ['{key}']" in error \ No newline at end of file + assert f"Missing required keys: ['{key}']" in error From 8e94ba78515a18046b91fe52f72f1fbc13e35038 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 18:46:20 +0100 Subject: [PATCH 12/14] fix: Merge conflicts --- policyengine_api/endpoints/tracer_analysis.py | 95 ------------------- .../routes/simulation_analysis_routes.py | 6 +- .../routes/tracer_analysis_routes.py | 6 +- 3 files changed, 2 insertions(+), 105 deletions(-) delete mode 100644 policyengine_api/endpoints/tracer_analysis.py diff --git a/policyengine_api/endpoints/tracer_analysis.py b/policyengine_api/endpoints/tracer_analysis.py deleted file mode 100644 index 68085888..00000000 --- a/policyengine_api/endpoints/tracer_analysis.py +++ /dev/null @@ -1,95 +0,0 @@ -from policyengine_api.data import local_database -import json -from flask import Response, request, stream_with_context -from policyengine_api.helpers import validate_country -from policyengine_api.ai_prompts import tracer_analysis_prompt -from policyengine_api.utils.ai_analysis import ( - trigger_ai_analysis, - get_existing_analysis, -) -from policyengine_api.utils.tracer_analysis import parse_tracer_output -from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS -from typing import Generator - -# Rename the file and get_tracer method to something more logical (Done) -# Change the database call to select based only on household_id, policy_id, and country_id (Done) -# Add a placeholder for a parsing function (to be completed later) – ideally, have it return some sample output -# Access the prompt and add the parsed tracer output -# Pass the complete prompt to the get_analysis function and return its response - -# TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit - - -@validate_country -def execute_tracer_analysis( - country_id: str, -): - """Get a tracer from the local database. 
- - Args: - country_id (str): The country ID. - """ - - payload = request.json - - household_id = payload.get("household_id") - policy_id = payload.get("policy_id") - variable = payload.get("variable") - - api_version = COUNTRY_PACKAGE_VERSIONS[country_id] - - # Retrieve from the tracers table in the local database - row = local_database.query( - """ - SELECT * FROM tracers - WHERE household_id = ? AND policy_id = ? AND country_id = ? AND api_version = ? - """, - (household_id, policy_id, country_id, api_version), - ).fetchone() - - # Fail if no tracer found - if row is None: - return Response( - status=404, - response={ - "message": "Unable to analyze household: no household simulation tracer found", - }, - ) - - # Parse the tracer output - tracer_output_list = json.loads(row["tracer_output"]) - try: - tracer_segment = parse_tracer_output(tracer_output_list, variable) - except Exception as e: - return Response( - status=500, - response={ - "message": "Error parsing tracer output", - "error": str(e), - }, - ) - - # Add the parsed tracer output to the prompt - prompt = tracer_analysis_prompt.format( - variable=variable, tracer_segment=tracer_segment - ) - - # If a calculated record exists for this prompt, return it as a - # streaming response - existing_analysis: Generator = get_existing_analysis(prompt) - if existing_analysis is not None: - return Response(stream_with_context(existing_analysis), status=200) - - # Otherwise, pass prompt to Claude, then return streaming function - try: - analysis: Generator = trigger_ai_analysis(prompt) - return Response(stream_with_context(analysis), status=200) - except Exception as e: - return Response( - status=500, - response={ - "message": "Error computing analysis", - "error": str(e), - }, - ) - diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index fcb280b2..5326989a 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -14,14 +14,10 @@ @simulation_analysis_bp.route("", methods=["POST"]) +@validate_country def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") - # Validate inbound country ID - invalid_country = validate_country(country_id) - if invalid_country: - return invalid_country - # Pop items from request payload and validate # where necessary payload = request.json diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 594028c1..35e5a403 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -10,13 +10,9 @@ @tracer_analysis_bp.route("", methods=["POST"]) +@validate_country def execute_tracer_analysis(country_id): - # Validate country ID - country_not_found = validate_country(country_id) - if country_not_found: - return country_not_found - payload = request.json is_payload_valid, message = validate_payload(payload) From d62e5ec1458efc18899750c115930152ee060787 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 18:48:17 +0100 Subject: [PATCH 13/14] fix: Update import route --- tests/python/test_validate_country.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_validate_country.py b/tests/python/test_validate_country.py index ddbd27a5..b162e4df 100644 --- a/tests/python/test_validate_country.py +++ b/tests/python/test_validate_country.py @@ -1,5 +1,5 @@ from flask import 
Response -from policyengine_api.helpers import validate_country +from policyengine_api.utils.payload_validators import validate_country @validate_country From 437dcbb316685092ababd07e51a10847fe486639 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 25 Nov 2024 18:49:57 +0100 Subject: [PATCH 14/14] fix: Move tracer analysis payload validator to proper folder --- .../routes/tracer_analysis_routes.py | 20 +++++-------------- .../utils/payload_validators/__init__.py | 1 + .../validate_tracer_analysis_payload.py | 11 ++++++++++ 3 files changed, 17 insertions(+), 15 deletions(-) create mode 100644 policyengine_api/utils/payload_validators/validate_tracer_analysis_payload.py diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 35e5a403..d15b0dd4 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,5 +1,8 @@ from flask import Blueprint, request, Response, stream_with_context -from policyengine_api.utils.payload_validators import validate_country +from policyengine_api.utils.payload_validators import ( + validate_country, + validate_tracer_analysis_payload, +) from policyengine_api.services.tracer_analysis_service import ( TracerAnalysisService, ) @@ -15,7 +18,7 @@ def execute_tracer_analysis(country_id): payload = request.json - is_payload_valid, message = validate_payload(payload) + is_payload_valid, message = validate_tracer_analysis_payload(payload) if not is_payload_valid: return Response( status=400, response=f"Invalid JSON data; details: {message}" @@ -70,16 +73,3 @@ def execute_tracer_analysis(country_id): ), 500, ) - - -def validate_payload(payload: dict): - # Validate payload - if not payload: - return False, "No payload provided" - - required_keys = ["household_id", "policy_id", "variable"] - for key in required_keys: - if key not in payload: - return False, f"Missing required key: {key}" - - return True, None diff --git a/policyengine_api/utils/payload_validators/__init__.py b/policyengine_api/utils/payload_validators/__init__.py index da141ea8..4bb15c1f 100644 --- a/policyengine_api/utils/payload_validators/__init__.py +++ b/policyengine_api/utils/payload_validators/__init__.py @@ -1,2 +1,3 @@ from .validate_sim_analysis_payload import validate_sim_analysis_payload +from .validate_tracer_analysis_payload import validate_tracer_analysis_payload from .validate_country import validate_country diff --git a/policyengine_api/utils/payload_validators/validate_tracer_analysis_payload.py b/policyengine_api/utils/payload_validators/validate_tracer_analysis_payload.py new file mode 100644 index 00000000..b17f3e35 --- /dev/null +++ b/policyengine_api/utils/payload_validators/validate_tracer_analysis_payload.py @@ -0,0 +1,11 @@ +def validate_tracer_analysis_payload(payload: dict): + # Validate payload + if not payload: + return False, "No payload provided" + + required_keys = ["household_id", "policy_id", "variable"] + for key in required_keys: + if key not in payload: + return False, f"Missing required key: {key}" + + return True, None
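
The validate_tracer_analysis_payload helper added in the last patch is small enough to unit-test directly. A minimal sketch of such tests, following the pytest conventions already used in tests/python/test_payload_validators.py; the fixture values below are hypothetical and this file is not part of the patch series:

import pytest

from policyengine_api.utils.payload_validators import (
    validate_tracer_analysis_payload,
)


@pytest.fixture
def valid_tracer_payload():
    # Hypothetical IDs; the validator only checks that the required
    # keys are present, not their types or values.
    return {
        "household_id": 1234,
        "policy_id": 5678,
        "variable": "household_net_income",
    }


def test_valid_tracer_payload(valid_tracer_payload):
    """A payload with all required keys passes validation."""
    is_valid, error = validate_tracer_analysis_payload(valid_tracer_payload)
    assert is_valid is True
    assert error is None


def test_missing_required_key(valid_tracer_payload):
    """Dropping a required key fails validation with a clear message."""
    del valid_tracer_payload["variable"]
    is_valid, error = validate_tracer_analysis_payload(valid_tracer_payload)
    assert is_valid is False
    assert "Missing required key: variable" in error


def test_empty_payload():
    """A missing or empty payload is rejected outright."""
    is_valid, error = validate_tracer_analysis_payload(None)
    assert is_valid is False
    assert error == "No payload provided"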
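
Patch 12 also replaces the inline validate_country(country_id) calls in the route handlers with a @validate_country decorator. As an illustration only of what such a decorator typically does: the real implementation lives in policyengine_api/utils/payload_validators/validate_country.py and is not shown in these patches, and the country set and response details below are assumptions, not the project's actual values.

from functools import wraps

from flask import Response

ASSUMED_COUNTRY_IDS = {"uk", "us", "ca", "ng", "il"}  # assumption, for illustration


def validate_country(route_handler):
    """Reject requests whose country_id path segment is unknown
    before the wrapped route handler runs."""

    @wraps(route_handler)
    def wrapper(country_id, *args, **kwargs):
        if country_id not in ASSUMED_COUNTRY_IDS:
            return Response(
                f"Country {country_id} not found",
                status=400,
            )
        return route_handler(country_id, *args, **kwargs)

    return wrapper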