From 865250fb59ac23495a7d5d870653209512b49368 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 Nov 2024 00:21:44 +0100 Subject: [PATCH] chore: Lint and changelog --- changelog_entry.yaml | 5 + policyengine_api/api.py | 14 +- .../routes/simulation_analysis_routes.py | 19 +- .../routes/tracer_analysis_routes.py | 26 ++- .../services/ai_analysis_service.py | 13 +- policyengine_api/services/economy_service.py | 1 + policyengine_api/services/job_service.py | 1 + policyengine_api/services/policy_service.py | 1 + .../services/simulation_analysis_service.py | 146 +++++++------ .../services/tracer_analysis_service.py | 205 +++++++++--------- 10 files changed, 235 insertions(+), 196 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..d1bc7e4e 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + changed: + - Refactored AI endpoints to match new routes/services/jobs architecture + - Disabled default buffering on App Engine deployments for AI endpoints \ No newline at end of file diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 162fa00d..3f3bd725 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -14,7 +14,9 @@ # Endpoints from policyengine_api.routes.economy_routes import economy_bp -from policyengine_api.routes.simulation_analysis_routes import simulation_analysis_bp +from policyengine_api.routes.simulation_analysis_routes import ( + simulation_analysis_bp, +) from policyengine_api.routes.tracer_analysis_routes import tracer_analysis_bp from .endpoints import ( get_home, @@ -95,9 +97,8 @@ # Routes for AI analysis of economy microsim runs app.register_blueprint( - simulation_analysis_bp, - url_prefix="//simulation-analysis" - ) + simulation_analysis_bp, url_prefix="//simulation-analysis" +) app.route("//user-policy", methods=["POST"])(set_user_policy) @@ -115,7 +116,10 @@ app.route("/simulations", methods=["GET"])(get_simulations) -app.register_blueprint(tracer_analysis_bp, url_prefix="//tracer-analysis") +app.register_blueprint( + tracer_analysis_bp, url_prefix="//tracer-analysis" +) + @app.route("/liveness-check", methods=["GET"]) def liveness_check(): diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index da7dfe77..73ad48fd 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,10 +1,13 @@ from flask import Blueprint, request, Response, stream_with_context from policyengine_api.helpers import validate_country -from policyengine_api.services.simulation_analysis_service import SimulationAnalysisService +from policyengine_api.services.simulation_analysis_service import ( + SimulationAnalysisService, +) simulation_analysis_bp = Blueprint("simulation_analysis", __name__) simulation_analysis_service = SimulationAnalysisService() + @simulation_analysis_bp.route("", methods=["POST"]) def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") @@ -20,7 +23,9 @@ def execute_simulation_analysis(country_id): is_payload_valid, message = validate_payload(payload) if not is_payload_valid: - return Response(status=400, response=f"Invalid JSON data; details: {message}") + return Response( + status=400, response=f"Invalid JSON data; details: {message}" + ) currency: str = payload.get("currency") selected_version: str = payload.get("selected_version") @@ -58,16 +63,18 @@ def execute_simulation_analysis(country_id): # Set header to prevent buffering on Google App Engine deployment # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) - response.headers['X-Accel-Buffering'] = 'no' + response.headers["X-Accel-Buffering"] = "no" return response except Exception as e: return { "status": "error", - "message": "An error occurred while executing the simulation analysis. Details: " + str(e), + "message": "An error occurred while executing the simulation analysis. Details: " + + str(e), "result": None, }, 500 + def validate_payload(payload: dict): # Check if all required keys are present; note # that the audience key is optional @@ -97,7 +104,7 @@ def validate_payload(payload: dict): missing_keys = [key for key in required_keys if key not in payload] if missing_keys: return False, f"Missing required keys: {missing_keys}" - + # Check if all keys are of the right type for key, value in payload.items(): if key in str_keys and not isinstance(value, str): @@ -107,4 +114,4 @@ def validate_payload(payload: dict): elif key in list_keys and not isinstance(value, list): return False, f"Key '{key}' must be a list" - return True, None \ No newline at end of file + return True, None diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index 2fcdae9a..4194ac98 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,10 +1,13 @@ from flask import Blueprint, request, Response, stream_with_context from policyengine_api.helpers import validate_country -from policyengine_api.services.tracer_analysis_service import TracerAnalysisService +from policyengine_api.services.tracer_analysis_service import ( + TracerAnalysisService, +) tracer_analysis_bp = Blueprint("tracer_analysis", __name__) tracer_analysis_service = TracerAnalysisService() + @tracer_analysis_bp.route("", methods=["POST"]) def execute_tracer_analysis(country_id): @@ -17,7 +20,9 @@ def execute_tracer_analysis(country_id): is_payload_valid, message = validate_payload(payload) if not is_payload_valid: - return Response(status=400, response=f"Invalid JSON data; details: {message}") + return Response( + status=400, response=f"Invalid JSON data; details: {message}" + ) household_id = payload.get("household_id") policy_id = payload.get("policy_id") @@ -36,20 +41,21 @@ def execute_tracer_analysis(country_id): ), status=200, ) - + # Set header to prevent buffering on Google App Engine deployment # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) - response.headers['X-Accel-Buffering'] = 'no' - + response.headers["X-Accel-Buffering"] = "no" + return response except Exception as e: return { - "status": "error", - "message": "An error occurred while executing the tracer analysis. Details: " - + str(e), - "result": None, + "status": "error", + "message": "An error occurred while executing the tracer analysis. Details: " + + str(e), + "result": None, }, 500 + def validate_payload(payload: dict): # Validate payload if not payload: @@ -60,4 +66,4 @@ def validate_payload(payload: dict): if key not in payload: return False, f"Missing required key: {key}" - return True, None \ No newline at end of file + return True, None diff --git a/policyengine_api/services/ai_analysis_service.py b/policyengine_api/services/ai_analysis_service.py index 8e71e2ee..fd72c932 100644 --- a/policyengine_api/services/ai_analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -5,6 +5,7 @@ from typing import Generator from policyengine_api.data import local_database + class AIAnalysisService: """ Base class for various AI analysis-based services, @@ -12,7 +13,9 @@ class AIAnalysisService: local database table """ - def get_existing_analysis(self, prompt: str) -> Generator[str, None, None] | None: + def get_existing_analysis( + self, prompt: str + ) -> Generator[str, None, None] | None: """ Get existing analysis from the local database """ @@ -43,9 +46,11 @@ def generate(): return generate() def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: - + # Configure a Claude client - claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + claude_client = anthropic.Anthropic( + api_key=os.getenv("ANTHROPIC_API_KEY") + ) def generate(): chunk_size = 5 @@ -84,4 +89,4 @@ def generate(): (prompt, response_text, "ok"), ) - return generate() \ No newline at end of file + return generate() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index da7bd34c..2806506b 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -18,6 +18,7 @@ class EconomyService: to the /economy route, which does not have its own table; therefore, it connects with other services to access their respective tables """ + def get_economic_impact( self, country_id, diff --git a/policyengine_api/services/job_service.py b/policyengine_api/services/job_service.py index f8d61c69..721e44ec 100644 --- a/policyengine_api/services/job_service.py +++ b/policyengine_api/services/job_service.py @@ -24,6 +24,7 @@ class JobService(metaclass=Singleton): jobs. This is not connected to any routes or tables, but interfaces with the Redis queue to enqueue jobs and track their status. """ + def __init__(self): self.recent_jobs = {} diff --git a/policyengine_api/services/policy_service.py b/policyengine_api/services/policy_service.py index 4367ff25..0a820f83 100644 --- a/policyengine_api/services/policy_service.py +++ b/policyengine_api/services/policy_service.py @@ -7,6 +7,7 @@ class PolicyService: this will be connected to the /policy route and is partially connected to the policy database table """ + def get_policy_json(self, country_id, policy_id): try: policy_json = database.query( diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index ed5c62bc..99aff720 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -2,37 +2,74 @@ from policyengine_api.services.ai_analysis_service import AIAnalysisService + class SimulationAnalysisService(AIAnalysisService): - """ - Service for generating AI analysis of economy-wide simulation - runs; this is connected with the simulation_analysis route and - analysis database table - """ - def __init__(self): - super().__init__() - - def execute_analysis( - self, - country_id: str, - currency: str, - selected_version: str, - time_period: str, - impact: dict, - policy_label: str, - policy: dict, - region: str, - relevant_parameters: list, - relevant_parameter_baseline_values: list, - audience: str | None, - ): - - # Check if the region is enhanced_cps - is_enhanced_cps = "enhanced_cps" in region - - print("Generating prompt for economy-wide simulation analysis") - - # Create prompt based on data - prompt = self._generate_simulation_analysis_prompt( + """ + Service for generating AI analysis of economy-wide simulation + runs; this is connected with the simulation_analysis route and + analysis database table + """ + + def __init__(self): + super().__init__() + + def execute_analysis( + self, + country_id: str, + currency: str, + selected_version: str, + time_period: str, + impact: dict, + policy_label: str, + policy: dict, + region: str, + relevant_parameters: list, + relevant_parameter_baseline_values: list, + audience: str | None, + ): + + # Check if the region is enhanced_cps + is_enhanced_cps = "enhanced_cps" in region + + print("Generating prompt for economy-wide simulation analysis") + + # Create prompt based on data + prompt = self._generate_simulation_analysis_prompt( + time_period, + region, + currency, + policy, + impact, + relevant_parameters, + relevant_parameter_baseline_values, + is_enhanced_cps, + selected_version, + country_id, + policy_label, + ) + + # Add audience description to end + prompt += self.audience_descriptions[audience] + + print("Checking if AI analysis already exists for this prompt") + # If a calculated record exists for this prompt, return it as a + # streaming response + existing_analysis = self.get_existing_analysis(prompt) + if existing_analysis is not None: + return existing_analysis + + print( + "Found no existing AI analysis; triggering new analysis with Claude" + ) + # Otherwise, pass prompt to Claude, then return streaming function + try: + analysis = self.trigger_ai_analysis(prompt) + return analysis + except Exception as e: + raise e + + def _generate_simulation_analysis_prompt( + self, time_period, region, currency, @@ -44,41 +81,8 @@ def execute_analysis( selected_version, country_id, policy_label, - ) - - # Add audience description to end - prompt += self.audience_descriptions[audience] - - print("Checking if AI analysis already exists for this prompt") - # If a calculated record exists for this prompt, return it as a - # streaming response - existing_analysis = self.get_existing_analysis(prompt) - if existing_analysis is not None: - return existing_analysis - - print("Found no existing AI analysis; triggering new analysis with Claude") - # Otherwise, pass prompt to Claude, then return streaming function - try: - analysis = self.trigger_ai_analysis(prompt) - return analysis - except Exception as e: - raise e - - def _generate_simulation_analysis_prompt( - self, - time_period, - region, - currency, - policy, - impact, - relevant_parameters, - relevant_parameter_baseline_values, - is_enhanced_cps, - selected_version, - country_id, - policy_label, - ): - return f""" + ): + return f""" I'm using PolicyEngine, a free, open source tool to compute the impact of public policy. I'm writing up an economic analysis of a hypothetical tax-benefit policy reform. Please write the analysis for me using the details below, in @@ -184,9 +188,9 @@ def _generate_simulation_analysis_prompt( impact["inequality"], )} """ - - audience_descriptions = { - "ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", - "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", - "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.", - } \ No newline at end of file + + audience_descriptions = { + "ELI5": "Write this for a layperson who doesn't know much about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", + "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", + "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.", + } diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index 19e1fd98..162245d2 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -6,114 +6,119 @@ import anthropic from policyengine_api.services.ai_analysis_service import AIAnalysisService + class TracerAnalysisService(AIAnalysisService): - def __init__(self): - super().__init__() - - def execute_analysis( - self, - country_id: str, - household_id: str, - policy_id: str, - variable: str, - ): - - api_version = COUNTRY_PACKAGE_VERSIONS[country_id] - - # Retrieve tracer record from table - try: - tracer: list[str] = self.get_tracer( - country_id, - household_id, - policy_id, - api_version, - ) - except Exception as e: - raise e - - # Parse the tracer output for our given variable - try: - tracer_segment: list[str] = self._parse_tracer_output(tracer, variable) - except Exception as e: - print(f"Error parsing tracer output: {str(e)}") - raise e - - # Add the parsed tracer output to the prompt - prompt = self.prompt_template.format( - variable=variable, tracer_segment=tracer_segment - ) - - # If a calculated record exists for this prompt, return it as a - # streaming response - existing_analysis: Generator = self.get_existing_analysis(prompt) - if existing_analysis is not None: - return existing_analysis - - # Otherwise, pass prompt to Claude, then return streaming function - try: - analysis: Generator = self.trigger_ai_analysis(prompt) - return analysis - except Exception as e: - print(f"Error generating AI analysis within tracer analysis service: {str(e)}") - raise e - - def get_tracer( + def __init__(self): + super().__init__() + + def execute_analysis( + self, + country_id: str, + household_id: str, + policy_id: str, + variable: str, + ): + + api_version = COUNTRY_PACKAGE_VERSIONS[country_id] + + # Retrieve tracer record from table + try: + tracer: list[str] = self.get_tracer( + country_id, + household_id, + policy_id, + api_version, + ) + except Exception as e: + raise e + + # Parse the tracer output for our given variable + try: + tracer_segment: list[str] = self._parse_tracer_output( + tracer, variable + ) + except Exception as e: + print(f"Error parsing tracer output: {str(e)}") + raise e + + # Add the parsed tracer output to the prompt + prompt = self.prompt_template.format( + variable=variable, tracer_segment=tracer_segment + ) + + # If a calculated record exists for this prompt, return it as a + # streaming response + existing_analysis: Generator = self.get_existing_analysis(prompt) + if existing_analysis is not None: + return existing_analysis + + # Otherwise, pass prompt to Claude, then return streaming function + try: + analysis: Generator = self.trigger_ai_analysis(prompt) + return analysis + except Exception as e: + print( + f"Error generating AI analysis within tracer analysis service: {str(e)}" + ) + raise e + + def get_tracer( self, country_id: str, household_id: str, policy_id: str, api_version: str, - ) -> list: - try: - # Retrieve from the tracers table in the local database - row = local_database.query( - """ + ) -> list: + try: + # Retrieve from the tracers table in the local database + row = local_database.query( + """ SELECT * FROM tracers WHERE household_id = ? AND policy_id = ? AND country_id = ? AND api_version = ? """, - (household_id, policy_id, country_id, api_version), - ).fetchone() - - if row is None: - raise KeyError("No tracer found for this household") - - tracer_output_list = json.loads(row["tracer_output"]) - return tracer_output_list - - except Exception as e: - print(f"Error getting existing tracer analysis: {str(e)}") - raise e - - def _parse_tracer_output(self, tracer_output, target_variable): - result = [] - target_indent = None - capturing = False - - # Create a regex pattern to match the exact variable name - # This will match the variable name followed by optional whitespace, - # then optional angle brackets with any content, then optional whitespace - pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*" - - for line in tracer_output: - # Count leading spaces to determine indentation level - indent = len(line) - len(line.strip()) - - # Check if this line matches our target variable - match = re.match(pattern, line) - if match and not capturing: - target_indent = indent - capturing = True - result.append(line) - elif capturing: - # Stop capturing if we encounter a line with less indentation than the target - if indent <= target_indent: - break - # Capture dependencies (lines with greater indentation) - result.append(line) - - return result - - prompt_template = f"""{anthropic.HUMAN_PROMPT} You are an AI assistant explaining US policy calculations. + (household_id, policy_id, country_id, api_version), + ).fetchone() + + if row is None: + raise KeyError("No tracer found for this household") + + tracer_output_list = json.loads(row["tracer_output"]) + return tracer_output_list + + except Exception as e: + print(f"Error getting existing tracer analysis: {str(e)}") + raise e + + def _parse_tracer_output(self, tracer_output, target_variable): + result = [] + target_indent = None + capturing = False + + # Create a regex pattern to match the exact variable name + # This will match the variable name followed by optional whitespace, + # then optional angle brackets with any content, then optional whitespace + pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*" + + for line in tracer_output: + # Count leading spaces to determine indentation level + indent = len(line) - len(line.strip()) + + # Check if this line matches our target variable + match = re.match(pattern, line) + if match and not capturing: + target_indent = indent + capturing = True + result.append(line) + elif capturing: + # Stop capturing if we encounter a line with less indentation than the target + if indent <= target_indent: + break + # Capture dependencies (lines with greater indentation) + result.append(line) + + return result + + prompt_template = f"""{anthropic.HUMAN_PROMPT} You are an AI assistant explaining US policy calculations. The user has run a simulation for the variable '{{variable}}'. Here's the tracer output: {{tracer_segment}} @@ -126,4 +131,4 @@ def _parse_tracer_output(self, tracer_output, target_variable): Keep your explanation concise but informative, suitable for a general audience. Do not start with phrases like "Certainly!" or "Here's an explanation. It will be rendered as markdown, so preface $ with \. - {anthropic.AI_PROMPT}""" \ No newline at end of file + {anthropic.AI_PROMPT}"""