From 91756f6b9bb190490279ee2d7cd235bd166cb9ae Mon Sep 17 00:00:00 2001
From: Andrei Fajardo
Date: Sun, 4 Aug 2024 01:13:03 -0400
Subject: [PATCH] add stats getter agent -- government essentials

---
 docker-compose.yml                            |  28 ++++
 .../agent_services/__init__.py                |   8 +
 .../government_essentials/__init__.py         |   0
 .../stats_getter_agent.py                     | 150 +++++++++++++++++++
 .../core_services/control_plane.py            |  25 +++-
 5 files changed, 208 insertions(+), 3 deletions(-)
 create mode 100644 snowflake_cybersyn_demo/agent_services/government_essentials/__init__.py
 create mode 100644 snowflake_cybersyn_demo/agent_services/government_essentials/stats_getter_agent.py

diff --git a/docker-compose.yml b/docker-compose.yml
index caa94b3..24ef636 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -124,6 +124,34 @@ services:
       retries: 5
       start_period: 20s
       timeout: 10s
+  stats_getter_agent:
+    image: snowflake_cybersyn_demo:latest
+    command: sh -c "python -m snowflake_cybersyn_demo.agent_services.government_essentials.stats_getter_agent"
+    env_file:
+      - .env.docker
+    ports:
+      - "8005:8005"
+    volumes:
+      - ./snowflake_cybersyn_demo:/app/snowflake_cybersyn_demo # load local code change to container without the need of rebuild
+      - ./data:/app/data
+      - ./logging.ini:/app/logging.ini
+    depends_on:
+      rabbitmq:
+        condition: service_healthy
+      control_plane:
+        condition: service_healthy
+    platform: linux/amd64
+    build:
+      context: .
+      dockerfile: ./Dockerfile
+      secrets:
+        - id_ed25519
+    healthcheck:
+      test: wget --no-verbose --tries=1 http://0.0.0.0:8005/is_worker_running || exit 1
+      interval: 30s
+      retries: 5
+      start_period: 20s
+      timeout: 10s
 volumes:
   rabbitmq:
 secrets:
diff --git a/snowflake_cybersyn_demo/agent_services/__init__.py b/snowflake_cybersyn_demo/agent_services/__init__.py
index 2c13e4b..58a253b 100644
--- a/snowflake_cybersyn_demo/agent_services/__init__.py
+++ b/snowflake_cybersyn_demo/agent_services/__init__.py
@@ -16,12 +16,20 @@
 from snowflake_cybersyn_demo.agent_services.funny_agent import (
     agent_server as funny_agent_server,
 )
+from snowflake_cybersyn_demo.agent_services.government_essentials.stats_getter_agent import (
+    agent_component as stats_getter_agent_component,
+)
+from snowflake_cybersyn_demo.agent_services.government_essentials.stats_getter_agent import (
+    agent_server as stats_getter_agent_server,
+)

 __all__ = [
     "goods_getter_agent_component",
     "goods_getter_agent_server",
     "time_series_getter_agent_component",
     "time_series_getter_agent_server",
+    "stats_getter_agent_component",
+    "stats_getter_agent_server",
     "funny_agent_server",
     "funny_agent_component",
 ]
diff --git a/snowflake_cybersyn_demo/agent_services/government_essentials/__init__.py b/snowflake_cybersyn_demo/agent_services/government_essentials/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/snowflake_cybersyn_demo/agent_services/government_essentials/stats_getter_agent.py b/snowflake_cybersyn_demo/agent_services/government_essentials/stats_getter_agent.py
new file mode 100644
index 0000000..676689b
--- /dev/null
+++ b/snowflake_cybersyn_demo/agent_services/government_essentials/stats_getter_agent.py
@@ -0,0 +1,150 @@
+import asyncio
+
+import uvicorn
+from llama_agents import AgentService, ServiceComponent
+from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
+from llama_index.agent.openai import OpenAIAgent
+from llama_index.core.tools import FunctionTool
+from llama_index.llms.openai import OpenAI
+from snowflake.sqlalchemy import URL
+from sqlalchemy import create_engine, text
+
+from snowflake_cybersyn_demo.utils import load_from_env
+
+message_queue_host = load_from_env("RABBITMQ_HOST")
+message_queue_port = load_from_env("RABBITMQ_NODE_PORT")
+message_queue_username = load_from_env("RABBITMQ_DEFAULT_USER")
+message_queue_password = load_from_env("RABBITMQ_DEFAULT_PASS")
+control_plane_host = load_from_env("CONTROL_PLANE_HOST")
+control_plane_port = load_from_env("CONTROL_PLANE_PORT")
+agent_host = load_from_env("STATS_GETTER_AGENT_HOST")
+agent_port = load_from_env("STATS_GETTER_AGENT_PORT")
+snowflake_user = load_from_env("SNOWFLAKE_USERNAME")
+snowflake_password = load_from_env("SNOWFLAKE_PASSWORD")
+snowflake_account = load_from_env("SNOWFLAKE_ACCOUNT")
+snowflake_role = load_from_env("SNOWFLAKE_ROLE")
+localhost = load_from_env("LOCALHOST")
+
+
+# create agent server
+message_queue = RabbitMQMessageQueue(
+    url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
+)
+
+# `:city` is a bound parameter; its value is supplied when the query executes
+SQL_QUERY_TEMPLATE = """
+SELECT DISTINCT
+    ts.variable_name
+FROM cybersyn.datacommons_timeseries AS ts
+JOIN cybersyn.geography_index AS geo ON (ts.geo_id = geo.geo_id)
+WHERE geo.geo_name = :city
+    AND geo.level IN ('City')
+    AND date >= '2015-01-01';
+"""
+
+AGENT_SYSTEM_PROMPT = """
+For a given query about a geographic and population statistic, your job is to first
+find the statistical variables that exists in the database.
+
+Return the list of the three most relevant statistical variables that exist in the
+database and that potentially match the object of the users query.
+
+Output your list in the following format:
+
+1. ...,
+2. ...,
+3. ...
+
+Be sure to use the exact variable names that were retrieved from the database tool!
+"""
+
+
+def get_list_of_statistical_variables(city: str, query: str) -> str:
+    """Returns a list of statistical variables that closely resemble the query.
+
+    The list of statistical vars is represented as a string separated by '\n'.
+
+    Args:
+        city: Name of the city whose available variables are looked up.
+        query: The user's question. Not used by the lookup itself; it stays
+            in the signature so existing callers keep working.
+    """
+    url = URL(
+        account=snowflake_account,
+        user=snowflake_user,
+        password=snowflake_password,
+        database="GOVERNMENT_ESSENTIALS",
+        schema="CYBERSYN",
+        warehouse="COMPUTE_WH",
+        role=snowflake_role,
+    )
+
+    engine = create_engine(url)
+    try:
+        # `city` is chosen by the LLM, so bind it rather than formatting it
+        # into the SQL text; fetch the rows before the connection closes.
+        with engine.connect() as connection:
+            rows = connection.execute(
+                text(SQL_QUERY_TEMPLATE), {"city": city}
+            ).fetchall()
+    finally:
+        # the engine is created per call, so release its pool as well
+        engine.dispose()
+
+    # process
+    results = [f"{ix+1}. {str(el[0])}" for ix, el in enumerate(rows)]
+    results_str = "List of statistical variables that exist in the database are provided below. Please select one.:\n\n"
+    results_str += "\n".join(results)
+
+    return results_str
+
+
+statistics_getter_tool = FunctionTool.from_defaults(
+    fn=get_list_of_statistical_variables
+)
+agent = OpenAIAgent.from_tools(
+    [statistics_getter_tool],
+    system_prompt=AGENT_SYSTEM_PROMPT,
+    llm=OpenAI(model="gpt-4o-mini"),
+    verbose=True,
+)
+
+agent_server = AgentService(
+    agent=agent,
+    message_queue=message_queue,
+    description="Retrieves the statistical variables that exist in the database that match the user's query.",
+    service_name="stats_getter_agent",
+    host=agent_host,
+    port=int(agent_port) if agent_port else None,
+)
+agent_component = ServiceComponent.from_service_definition(agent_server)
+
+app = agent_server._app
+
+
+# launch
+async def launch() -> None:
+    # register to message queue
+    start_consuming_callable = await agent_server.register_to_message_queue()
+    _ = asyncio.create_task(start_consuming_callable())
+
+    # register to control plane
+    await agent_server.register_to_control_plane(
+        control_plane_url=(
+            f"http://{control_plane_host}:{control_plane_port}"
+            if control_plane_port
+            else f"http://{control_plane_host}"
+        )
+    )
+
+    cfg = uvicorn.Config(
+        agent_server._app,
+        host=localhost,
+        port=agent_server.port,
+    )
+    server = uvicorn.Server(cfg)
+    await server.serve()
+
+
+if __name__ == "__main__":
+    asyncio.run(launch())
diff --git a/snowflake_cybersyn_demo/core_services/control_plane.py b/snowflake_cybersyn_demo/core_services/control_plane.py
index 97438fd..a2e093f 100644
--- a/snowflake_cybersyn_demo/core_services/control_plane.py
+++ b/snowflake_cybersyn_demo/core_services/control_plane.py
@@ -18,6 +18,7 @@
     funny_agent_component,
     funny_agent_server,
     goods_getter_agent_component,
+    stats_getter_agent_component,
     time_series_getter_agent_component,
 )
 from snowflake_cybersyn_demo.utils import load_from_env
@@ -36,7 +37,7 @@
     url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
 )

-
+# historical prices of a good pipeline
 timeseries_task_pipeline = QueryPipeline(
     chain=[
         goods_getter_agent_component,
@@ -51,16 +52,34 @@
 (timeseries) data for a specified good from the database.
 """

+# government statistics pipeline
+city_stats_pipeline = QueryPipeline(
+    chain=[
+        stats_getter_agent_component,
+        human_component,
+    ],
+)
+city_stats_pipeline_orchestrator = PipelineOrchestrator(city_stats_pipeline)
+city_stats_pipeline_desc = """Only used for getting geographic and demographic
+statistics for a specified city.
+"""
+
+# general pipeline
 general_pipeline = QueryPipeline(chain=[funny_agent_component])
 general_pipeline_orchestrator = PipelineOrchestrator(general_pipeline)

 pipeline_orchestrator = OrchestratorRouter(
-    selector=PydanticSingleSelector.from_defaults(llm=OpenAI()),
+    selector=PydanticSingleSelector.from_defaults(llm=OpenAI("gpt-4o-mini")),
     orchestrators=[
         timeseries_pipeline_orchestrator,
+        city_stats_pipeline_orchestrator,
         general_pipeline_orchestrator,
     ],
-    choices=[timeseries_task_pipeline_desc, funny_agent_server.description],
+    choices=[
+        timeseries_task_pipeline_desc,
+        city_stats_pipeline_desc,
+        funny_agent_server.description,
+    ],
 )

 # setup control plane