From 88d953df8774f53453d8984a639eb555a1fb46b4 Mon Sep 17 00:00:00 2001
From: Andrei Fajardo
Date: Tue, 6 Aug 2024 14:17:51 -0400
Subject: [PATCH] govt essentials agent working

---
 .../time_series_getter_agent.py            |  4 +-
 .../stats_fulfiller_agent.py               |  8 +-
 snowflake_cybersyn_demo/apps/controller.py | 10 +-
 snowflake_cybersyn_demo/apps/streamlit.py  | 91 +++++++++++--------
 4 files changed, 67 insertions(+), 46 deletions(-)

diff --git a/snowflake_cybersyn_demo/agent_services/financial_and_economic_essentials/time_series_getter_agent.py b/snowflake_cybersyn_demo/agent_services/financial_and_economic_essentials/time_series_getter_agent.py
index aa19a1b..61932dd 100644
--- a/snowflake_cybersyn_demo/agent_services/financial_and_economic_essentials/time_series_getter_agent.py
+++ b/snowflake_cybersyn_demo/agent_services/financial_and_economic_essentials/time_series_getter_agent.py
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import uvicorn
 from llama_agents import AgentService, ServiceComponent
@@ -102,7 +102,7 @@ def get_time_series_of_good(good: str) -> str:
     return results_str
 
 
-def perform_price_aggregation(json_str: str) -> str:
+def perform_price_aggregation(json_str: str) -> List[Dict[str, Any]]:
     """Perform price aggregation on the time series data."""
     timeseries_data = json.loads(json_str)
     good = timeseries_data[0]["good"]
diff --git a/snowflake_cybersyn_demo/agent_services/government_essentials/stats_fulfiller_agent.py b/snowflake_cybersyn_demo/agent_services/government_essentials/stats_fulfiller_agent.py
index 9356786..cc2a99f 100644
--- a/snowflake_cybersyn_demo/agent_services/government_essentials/stats_fulfiller_agent.py
+++ b/snowflake_cybersyn_demo/agent_services/government_essentials/stats_fulfiller_agent.py
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import logging
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import uvicorn
 from llama_agents import AgentService, ServiceComponent
@@ -12,10 +12,10 @@
 from snowflake.sqlalchemy import URL
 from sqlalchemy import create_engine, text
 
-logger = logging.getLogger(__name__)
+from snowflake_cybersyn_demo.utils import load_from_env
 
+logger = logging.getLogger(__name__)
 
-from snowflake_cybersyn_demo.utils import load_from_env
 
 message_queue_host = load_from_env("RABBITMQ_HOST")
 message_queue_port = load_from_env("RABBITMQ_NODE_PORT")
@@ -113,7 +113,7 @@ def get_time_series_of_statistic_variable(
     return results_str
 
 
-def perform_date_value_aggregation(json_str: str) -> str:
+def perform_date_value_aggregation(json_str: str) -> List[Dict[str, Any]]:
     """Perform value aggregation on the time series data."""
     timeseries_data = json.loads(json_str)
     variable = timeseries_data[0]["variable"]
diff --git a/snowflake_cybersyn_demo/apps/controller.py b/snowflake_cybersyn_demo/apps/controller.py
index 6c251f8..fa826f4 100644
--- a/snowflake_cybersyn_demo/apps/controller.py
+++ b/snowflake_cybersyn_demo/apps/controller.py
@@ -3,7 +3,7 @@
 import queue
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Dict, Generator, List, Optional
+from typing import Any, Callable, Generator, List, Optional
 
 import pandas as pd
 import streamlit as st
@@ -197,14 +197,16 @@ def task_selection_handler() -> None:
         return task_selection_handler
 
     def infer_task_type(self, task_res: TaskResult) -> str:
-        def try_parse_as_json(text: str) -> Optional[Dict]:
+        def try_parse_as_json(text: str) -> Any:
             try:
                 return json.loads(text)
             except json.JSONDecodeError:
-                return {}
+                return None
 
         if task_res_json := try_parse_as_json(task_res.result):
             if "good" in task_res_json[0]:
-                return "timeseries"
+                return "timeseries-good"
+            if "variable" in task_res_json[0]:
+                return "timeseries-city-stat"
 
         return "text"
diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py
index ef2087b..9cc49b3 100644
--- a/snowflake_cybersyn_demo/apps/streamlit.py
+++ b/snowflake_cybersyn_demo/apps/streamlit.py
@@ -18,6 +18,9 @@
 from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.time_series_getter_agent import (
     perform_price_aggregation,
 )
+from snowflake_cybersyn_demo.agent_services.government_essentials.stats_fulfiller_agent import (
+    perform_date_value_aggregation,
+)
 from snowflake_cybersyn_demo.apps.controller import Controller
 from snowflake_cybersyn_demo.apps.final_task_consumer import FinalTaskConsumer
 
@@ -32,13 +35,15 @@
 
 
 @st.cache_resource
-def startup() -> Tuple[
-    Controller,
-    queue.Queue[TaskResult],
-    FinalTaskConsumer,
-    queue.Queue[HumanRequest],
-    queue.Queue[str],
-]:
+def startup() -> (
+    Tuple[
+        Controller,
+        queue.Queue[TaskResult],
+        FinalTaskConsumer,
+        queue.Queue[HumanRequest],
+        queue.Queue[str],
+    ]
+):
     from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
         human_input_request_queue,
         human_input_result_queue,
@@ -61,7 +66,9 @@ async def start_consuming_human_tasks(hs: HumanService) -> None:
         )
     )
 
-    consuming_callable = await message_queue.register_consumer(hs.as_consumer())
+    consuming_callable = await message_queue.register_consumer(
+        hs.as_consumer()
+    )
 
     ht_task = asyncio.create_task(consuming_callable())  # noqa: F841
@@ -155,25 +162,6 @@ async def start_consuming_finalized_tasks(
     )
 
 
-def chat_window() -> None:
-    pass
-    # with st.sidebar:
-
-    #     @st.experimental_fragment(run_every="5s")
-    #     def show_chat_window() -> None:
-    #         messages_container = st.container(height=500)
-    #         with messages_container:
-    #             if st.session_state.current_task:
-    #                 messages = [m.dict() for m in st.session_state.current_task.history]
-    #                 for message in messages:
-    #                     with st.chat_message(message["role"]):
-    #                         st.markdown(message["content"])
-    #             else:
-    #                 st.empty()
-
-    #     show_chat_window()
-
-
 @st.experimental_fragment(run_every="5s")
 def task_df() -> None:
     st.text("Task Status")
@@ -224,7 +212,9 @@ def task_df() -> None:
     st.text_input(
         "Provide human input",
         key="human_input",
-        on_change=controller.get_human_input_handler(human_input_result_queue),
+        on_change=controller.get_human_input_handler(
+            human_input_result_queue
+        ),
     )
 
     show_task_res = (
@@ -240,22 +230,47 @@ def task_df() -> None:
         task_type = controller.infer_task_type(task_res)
 
         timeseries_data = None
-        if task_type == "timeseries":
+        value_key: str = ""
+        object_key: str = ""
+        color: str = ""
+        if task_type == "timeseries-good":
             try:
-                timeseries_data = perform_price_aggregation(task_res.result)
+                timeseries_data = perform_price_aggregation(
+                    task_res.result
+                )
+                value_key = "price"
+                object_key = "good"
+                color = "#FF91AF"
+            except json.JSONDecodeError:
+                logger.info("Could not decode task_res")
+                pass
+        elif task_type == "timeseries-city-stat":
+            try:
+                timeseries_data = perform_date_value_aggregation(
+                    task_res.result
+                )
+                value_key = "value"
+                object_key = "variable"
+                color = "#73CED0"
             except json.JSONDecodeError:
                 logger.info("Could not decode task_res")
                 pass
 
         with task_res_container:
             if timeseries_data:
-                title = timeseries_data[0]["good"]
-                timeseries_data = {
+                title = timeseries_data[0][object_key]
+                chart_data = {
                     "dates": [el["date"] for el in timeseries_data],
-                    "price": [el["price"] for el in timeseries_data],
+                    value_key: [el[value_key] for el in timeseries_data],
                 }
                 st.header(title)
-                st.bar_chart(data=timeseries_data, x="dates", y="price", height=400)
+                st.bar_chart(
+                    data=chart_data,
+                    x="dates",
+                    y=value_key,
+                    height=400,
+                    color=color,
+                )
             else:
                 st.write(task_res.result)
@@ -273,7 +288,9 @@ def process_completed_tasks(completed_queue: queue.Queue) -> None:
         logger.info("task result queue is empty.")
 
     if task_res:
-        controller.update_associated_task_to_completed_status(task_res=task_res)
+        controller.update_associated_task_to_completed_status(
+            task_res=task_res
+        )
 
 
 process_completed_tasks(completed_queue=completed_tasks_queue)
@@ -291,7 +308,9 @@ def process_human_input_requests(
         logger.info("human request queue is empty.")
 
     if human_req:
-        controller.update_associated_task_to_human_required_status(human_req=human_req)
+        controller.update_associated_task_to_human_required_status(
+            human_req=human_req
+        )
 
 
 process_human_input_requests(human_requests_queue=human_input_request_queue)
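
Reviewer note: the sketch below is not part of the patch; it is a minimal
stand-alone rendering of the new task-type dispatch and chart-data pivot
introduced above. Assumptions: infer_task_type here takes the raw result
string rather than a TaskResult, and the two sample payloads are hypothetical
stand-ins for the outputs of perform_price_aggregation and
perform_date_value_aggregation.

import json
from typing import Any


def try_parse_as_json(text: str) -> Any:
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None  # per the patch: None (was {}) signals a failed parse


def infer_task_type(result: str) -> str:
    # Mirrors Controller.infer_task_type after this patch: a "good" key in
    # the first record marks a goods time series; "variable" marks city stats.
    if task_res_json := try_parse_as_json(result):
        if "good" in task_res_json[0]:
            return "timeseries-good"
        if "variable" in task_res_json[0]:
            return "timeseries-city-stat"
    return "text"


# Hypothetical aggregated payloads (shapes assumed from the diff):
goods_res = json.dumps([{"good": "Eggs", "date": "2024-01", "price": 2.99}])
stats_res = json.dumps(
    [{"variable": "Population", "date": "2024-01", "value": 8804190}]
)

assert infer_task_type(goods_res) == "timeseries-good"
assert infer_task_type(stats_res) == "timeseries-city-stat"
assert infer_task_type("plain text answer") == "text"

# streamlit.py then pivots the rows into the dict st.bar_chart consumes,
# keyed by the branch-specific value_key ("price" or "value"):
value_key = "value"
timeseries_data = json.loads(stats_res)
chart_data = {
    "dates": [el["date"] for el in timeseries_data],
    value_key: [el[value_key] for el in timeseries_data],
}
assert chart_data == {"dates": ["2024-01"], "value": [8804190]}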