From 44308a44c5930e2657f8e4555e74409e03e6202c Mon Sep 17 00:00:00 2001
From: Andrei Fajardo
Date: Sat, 3 Aug 2024 02:28:31 -0400
Subject: [PATCH] router

---
 .../agent_services/__init__.py              | 27 ++++++++++++
 .../agent_services/funny_agent.py           | 28 ++-----------
 snowflake_cybersyn_demo/apps/controller.py  | 16 +++++++-
 snowflake_cybersyn_demo/apps/streamlit.py   | 41 ++++++++++++-------
 .../core_services/control_plane.py          | 38 +++++++++++++----
 5 files changed, 101 insertions(+), 49 deletions(-)

diff --git a/snowflake_cybersyn_demo/agent_services/__init__.py b/snowflake_cybersyn_demo/agent_services/__init__.py
index e69de29..2c13e4b 100644
--- a/snowflake_cybersyn_demo/agent_services/__init__.py
+++ b/snowflake_cybersyn_demo/agent_services/__init__.py
@@ -0,0 +1,27 @@
+from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.goods_getter_agent import (
+    agent_component as goods_getter_agent_component,
+)
+from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.goods_getter_agent import (
+    agent_server as goods_getter_agent_server,
+)
+from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.time_series_getter_agent import (
+    agent_component as time_series_getter_agent_component,
+)
+from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.time_series_getter_agent import (
+    agent_server as time_series_getter_agent_server,
+)
+from snowflake_cybersyn_demo.agent_services.funny_agent import (
+    agent_component as funny_agent_component,
+)
+from snowflake_cybersyn_demo.agent_services.funny_agent import (
+    agent_server as funny_agent_server,
+)
+
+__all__ = [
+    "goods_getter_agent_component",
+    "goods_getter_agent_server",
+    "time_series_getter_agent_component",
+    "time_series_getter_agent_server",
+    "funny_agent_server",
+    "funny_agent_component",
+]
diff --git a/snowflake_cybersyn_demo/agent_services/funny_agent.py b/snowflake_cybersyn_demo/agent_services/funny_agent.py
index cec39db..be84b1e 100644
--- a/snowflake_cybersyn_demo/agent_services/funny_agent.py
+++ b/snowflake_cybersyn_demo/agent_services/funny_agent.py
@@ -1,12 +1,10 @@
 import asyncio
-from pathlib import Path
 
 import uvicorn
 from llama_agents import AgentService, ServiceComponent
 from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
 from llama_index.agent.openai import OpenAIAgent
-from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
-from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolMetadata
+from llama_index.core.tools import FunctionTool
 from llama_index.llms.openai import OpenAI
 
 from snowflake_cybersyn_demo.utils import load_from_env
@@ -36,34 +34,16 @@ def get_the_secret_fact() -> str:
 
 secret_fact_tool = FunctionTool.from_defaults(fn=get_the_secret_fact)
 
-# rag tool
-data_path = Path(Path(__file__).parents[2].absolute(), "data").as_posix()
-print(data_path)
-loader = SimpleDirectoryReader(input_dir=data_path)
-documents = loader.load_data()
-index = VectorStoreIndex.from_documents(documents)
-query_engine = index.as_query_engine(llm=OpenAI(model="gpt-4o"))
-query_engine_tool = QueryEngineTool(
-    query_engine=query_engine,
-    metadata=ToolMetadata(
-        name="paul_graham_tool",
-        description=(
-            "Provides information about Paul Graham and his written essays."
-        ),
-    ),
-)
-
-
 agent = OpenAIAgent.from_tools(
-    [secret_fact_tool, query_engine_tool],
-    system_prompt="Knows about Paul Graham, the secret fact, and is able to tell a funny joke.",
+    [secret_fact_tool],
+    system_prompt="Knows the secret fact, and can tell funny jokes.",
     llm=OpenAI(model="gpt-4o"),
     verbose=True,
 )
 agent_server = AgentService(
     agent=agent,
     message_queue=message_queue,
-    description="Useful for everything but math, and especially telling funny jokes and anything about Paul Graham.",
+    description="Cannot get timeseries data of specified good, but can handle all other queries.",
     service_name="funny_agent",
     host=funny_agent_host,
     port=int(funny_agent_port) if funny_agent_port else None,
diff --git a/snowflake_cybersyn_demo/apps/controller.py b/snowflake_cybersyn_demo/apps/controller.py
index e87b218..6c251f8 100644
--- a/snowflake_cybersyn_demo/apps/controller.py
+++ b/snowflake_cybersyn_demo/apps/controller.py
@@ -1,8 +1,9 @@
+import json
 import logging
 import queue
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Generator, List, Optional
+from typing import Any, Callable, Dict, Generator, List, Optional
 
 import pandas as pd
 import streamlit as st
@@ -194,3 +195,16 @@ def task_selection_handler() -> None:
                 pass  # handle this better
 
         return task_selection_handler
+
+    def infer_task_type(self, task_res: TaskResult) -> str:
+        def try_parse_as_json(text: str) -> Optional[Dict]:
+            try:
+                return json.loads(text)
+            except json.JSONDecodeError:
+                return {}
+
+        if task_res_json := try_parse_as_json(task_res.result):
+            if "good" in task_res_json[0]:
+                return "timeseries"
+
+        return "text"
diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py
index 1fd4e56..7650a8b 100644
--- a/snowflake_cybersyn_demo/apps/streamlit.py
+++ b/snowflake_cybersyn_demo/apps/streamlit.py
@@ -238,25 +238,36 @@ def task_df() -> None:
         and st.session_state.current_task.status == "completed"
     )
 
+    task_res_container = st.container(height=500)
     if show_task_res:
         if task_res := controller.get_task_result(
             st.session_state.current_task.task_id
         ):
-            try:
-                timeseries_data = perform_price_aggregation(task_res.result)
-            except json.JSONDecodeError:
-                logger.info("Could not decode task_res")
-                pass
-            title = timeseries_data[0]["good"]
-            timeseries_data = {
-                "dates": [el["date"] for el in timeseries_data],
-                "price": [el["price"] for el in timeseries_data],
-            }
-            with st.container(height=500):
-                st.header(title)
-                st.bar_chart(
-                    data=timeseries_data, x="dates", y="price", height=400
-                )
+            task_type = controller.infer_task_type(task_res)
+
+            timeseries_data = None
+            if task_type == "timeseries":
+                try:
+                    timeseries_data = perform_price_aggregation(
+                        task_res.result
+                    )
+                except json.JSONDecodeError:
+                    logger.info("Could not decode task_res")
+                    pass
+
+            with task_res_container:
+                if timeseries_data:
+                    title = timeseries_data[0]["good"]
+                    timeseries_data = {
+                        "dates": [el["date"] for el in timeseries_data],
+                        "price": [el["price"] for el in timeseries_data],
+                    }
+                    st.header(title)
+                    st.bar_chart(
+                        data=timeseries_data, x="dates", y="price", height=400
+                    )
+                else:
+                    st.write(task_res.result)
 
 
 task_df()
diff --git a/snowflake_cybersyn_demo/core_services/control_plane.py b/snowflake_cybersyn_demo/core_services/control_plane.py
index 74ad267..97674ef 100644
--- a/snowflake_cybersyn_demo/core_services/control_plane.py
+++ b/snowflake_cybersyn_demo/core_services/control_plane.py
@@ -3,16 +3,18 @@
 import uvicorn
 from llama_agents import ControlPlaneServer, PipelineOrchestrator
 from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
-from llama_index.core.query_pipeline import QueryPipeline
+from llama_index.core.query_pipeline import QueryPipeline, RouterComponent
+from llama_index.core.selectors import PydanticSingleSelector
+from llama_index.llms.openai import OpenAI
 
 from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
     human_component,
 )
-from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.goods_getter_agent import (
-    agent_component as goods_getter_component,
-)
-from snowflake_cybersyn_demo.agent_services.financial_and_economic_essentials.time_series_getter_agent import (
-    agent_component as time_series_getter_component,
+from snowflake_cybersyn_demo.agent_services import (
+    funny_agent_component,
+    funny_agent_server,
+    goods_getter_agent_component,
+    time_series_getter_agent_component,
 )
 from snowflake_cybersyn_demo.utils import load_from_env
 
@@ -30,13 +32,31 @@
     url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
 )
 
-pipeline = QueryPipeline(
+
+timeseries_task_pipeline = QueryPipeline(
     chain=[
-        goods_getter_component,
+        goods_getter_agent_component,
         human_component,
-        time_series_getter_component,
+        time_series_getter_agent_component,
+    ],
+)
+timeseries_task_pipeline_desc = (
+    "Only used for getting timeseries data from the database."
+)
+
+pipeline = QueryPipeline(
+    chain=[
+        RouterComponent(
+            selector=PydanticSingleSelector.from_defaults(llm=OpenAI()),
+            choices=[
+                funny_agent_server.description,
+                timeseries_task_pipeline_desc,
+            ],
+            components=[funny_agent_component, timeseries_task_pipeline],
+        )
     ]
 )
+
 pipeline_orchestrator = PipelineOrchestrator(pipeline)
 
 # setup control plane
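
Note (outside the diff): the new control plane routes every task through a RouterComponent whose PydanticSingleSelector picks between two natural-language choices, which is why funny_agent's description was rewritten to advertise what it cannot do. The snippet below is a minimal sketch of that selection step in isolation, not code from this patch; it assumes the llama-index selector API used above (PydanticSingleSelector.from_defaults, a synchronous select() over ToolMetadata choices returning indexed selections) and needs OPENAI_API_KEY set to run.

# Sketch only -- not part of the patch. Exercises the same selector the new
# control plane uses, against the two route descriptions introduced above.
from llama_index.core.selectors import PydanticSingleSelector
from llama_index.core.tools import ToolMetadata
from llama_index.llms.openai import OpenAI

choices = [
    ToolMetadata(
        name="funny_agent",
        description=(
            "Cannot get timeseries data of specified good, but can handle "
            "all other queries."
        ),
    ),
    ToolMetadata(
        name="timeseries_task_pipeline",
        description="Only used for getting timeseries data from the database.",
    ),
]
selector = PydanticSingleSelector.from_defaults(llm=OpenAI())

# A data question should pick index 1 (the timeseries pipeline) ...
result = selector.select(choices, query="Show the price history of eggs")
print(result.selections[0].index, result.selections[0].reason)

# ... while anything else should fall through to index 0 (funny_agent).
result = selector.select(choices, query="Tell me a joke")
print(result.selections[0].index, result.selections[0].reason)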
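
Note (outside the diff): on the app side, controller.infer_task_type decides how streamlit.py renders a completed task: a result that parses as JSON and whose first record carries a "good" key is charted as a timeseries, everything else is written out as text. Below is a small self-contained restatement of that heuristic for illustration; FakeTaskResult and the sample payloads are hypothetical stand-ins (the real method receives a llama_agents TaskResult), and the isinstance check is slightly more defensive than the version added in the patch.

# Sketch only -- not part of the patch. Mirrors the routing heuristic that
# infer_task_type adds to snowflake_cybersyn_demo/apps/controller.py.
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeTaskResult:
    # Hypothetical stand-in for the TaskResult the controller receives;
    # only the raw string result matters here.
    result: str


def infer_task_type(task_res: FakeTaskResult) -> str:
    def try_parse_as_json(text: str) -> Optional[Any]:
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    parsed = try_parse_as_json(task_res.result)
    # Timeseries results come back as a JSON list of records keyed by "good".
    if isinstance(parsed, list) and parsed and "good" in parsed[0]:
        return "timeseries"
    return "text"


# A JSON list of price records routes to the bar-chart branch in streamlit.py ...
timeseries_payload = json.dumps(
    [{"good": "Eggs", "date": "2023-01-01", "price": 2.50}]
)
print(infer_task_type(FakeTaskResult(result=timeseries_payload)))  # timeseries

# ... while a plain answer from funny_agent is rendered with st.write.
print(infer_task_type(FakeTaskResult(result="Sure, here is a joke.")))  # text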