From 44885c19ba5e3241de318da5b6e2dbfdf6a0780a Mon Sep 17 00:00:00 2001
From: Andrei Fajardo
Date: Thu, 25 Jul 2024 11:32:34 -0400
Subject: [PATCH] local launcher and app controller

---
 snowflake_cybersyn_demo/apps/controller.py | 88 ++++++++++++++++------
 snowflake_cybersyn_demo/apps/streamlit.py  | 21 +++---
 snowflake_cybersyn_demo/local_launcher.py  | 18 +++++
 3 files changed, 96 insertions(+), 31 deletions(-)
 create mode 100644 snowflake_cybersyn_demo/local_launcher.py

diff --git a/snowflake_cybersyn_demo/apps/controller.py b/snowflake_cybersyn_demo/apps/controller.py
index 98873c1..204240f 100644
--- a/snowflake_cybersyn_demo/apps/controller.py
+++ b/snowflake_cybersyn_demo/apps/controller.py
@@ -1,11 +1,20 @@
+import asyncio
+import logging
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Generator, List, Optional
 
 import streamlit as st
-from llama_agents import LlamaAgentsClient
+from llama_agents import (
+    CallableMessageConsumer,
+    LlamaAgentsClient,
+    QueueMessage,
+)
+from llama_agents.types import ActionTypes, TaskResult
 from llama_index.core.llms import ChatMessage, ChatResponseGen
 
+logger = logging.getLogger(__name__)
+
 
 class TaskStatus(str, Enum):
     HUMAN_REQUIRED = "human_required"
@@ -23,8 +32,47 @@ class TaskModel:
 
 
 class Controller:
-    def __init__(self, llama_agents_client: Optional[LlamaAgentsClient]):
-        self._client = llama_agents_client
+    def __init__(
+        self,
+        human_in_loop_queue: asyncio.Queue,
+        human_in_loop_result_queue: asyncio.Queue,
+        control_plane_host: str = "127.0.0.1",
+        control_plane_port: Optional[int] = 8000,
+    ):
+        self.human_in_loop_queue = human_in_loop_queue
+        self.human_in_loop_result_queue = human_in_loop_result_queue
+        self._client = LlamaAgentsClient(
+            control_plane_url=(
+                f"http://{control_plane_host}:{control_plane_port}"
+                if control_plane_port
+                else f"http://{control_plane_host}"
+            )
+        )
+        self._step_interval = 0.5
+        self._timeout = 60
+        self._raise_timeout = False
+        self._human_in_the_loop_task: Optional[str] = None
+        self._human_input: Optional[str] = None
+        self._final_task_consumer = CallableMessageConsumer(
+            message_type="human", handler=self._process_completed_task_messages
+        )
+        self._completed_tasks_queue: asyncio.Queue[
+            TaskResult
+        ] = asyncio.Queue()
+
+    async def _process_completed_task_messages(
+        self, message: QueueMessage, **kwargs: Any
+    ) -> None:
+        """Consumer of completed tasks.
+
+        By default, the control plane sends completed tasks to the message
+        consumer of type "human". The handler here simply puts each TaskResult
+        into a queue that the Streamlit app polls continuously.
+        """
+        if message.action == ActionTypes.COMPLETED_TASK:
+            task_res = TaskResult(**message.data)
+            await self._completed_tasks_queue.put(task_res)
+            logger.info("Added task result to queue")
 
     def _llama_index_stream_wrapper(
         self,
@@ -33,27 +81,25 @@ def _llama_index_stream_wrapper(
         for chunk in llama_index_stream:
             yield chunk.delta
 
-    def _handle_task_submission(
-        self, llama_agents_client: LlamaAgentsClient
-    ) -> None:
+    def _handle_task_submission(self) -> None:
         """Handle the user submitted message. Clear task submission box, and
         add the new task to the submitted list.
""" # create new task and store in state - # task_input = st.session_state.task_input - # task_id = self._client.create_task(task_input) - # task = TaskModel( - # task_id=task_id, - # input=task_input, - # chat_history=[ - # ChatMessage(role="user", content=task_input), - # ChatMessage( - # role="assistant", - # content=f"Successfully submitted task: {task_id}.", - # ), - # ], - # status=TaskStatus.SUBMITTED, - # ) + task_input = st.session_state.task_input + task_id = self._client.create_task(task_input) + task = TaskModel( + task_id=task_id, + input=task_input, + chat_history=[ + ChatMessage(role="user", content=task_input), + ChatMessage( + role="assistant", + content=f"Successfully submitted task: {task_id}.", + ), + ], + status=TaskStatus.SUBMITTED, + ) st.session_state.submitted_pills.append(st.session_state.task_input) - # st.session_state.tasks.append(task) + st.session_state.tasks.append(task) diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py index 3e36f26..f37ad17 100644 --- a/snowflake_cybersyn_demo/apps/streamlit.py +++ b/snowflake_cybersyn_demo/apps/streamlit.py @@ -1,3 +1,6 @@ +import asyncio +from typing import Dict + import pandas as pd import streamlit as st from llama_index.core.llms import ChatMessage @@ -8,15 +11,14 @@ llm = OpenAI(model="gpt-4o-mini") control_plane_host = "0.0.0.0" control_plane_port = 8001 -# llama_agents_client = LlamaAgentsClient( -# control_plane_url=( -# f"http://{control_plane_host}:{control_plane_port}" -# if control_plane_port -# else f"http://{control_plane_host}" -# ) -# ) -llama_agents_client = None -controller = Controller(llama_agents_client) +human_input_request_queue: asyncio.Queue[Dict[str, str]] = asyncio.Queue() +human_input_result_queue: asyncio.Queue[str] = asyncio.Queue() +controller = Controller( + human_in_loop_queue=human_input_request_queue, + human_in_loop_result_queue=human_input_result_queue, + control_plane_host=control_plane_host, + control_plane_port=control_plane_port, +) ### App st.set_page_config(layout="wide") @@ -39,7 +41,6 @@ placeholder="Enter a task input.", key="task_input", on_change=controller._handle_task_submission, - args=(llama_agents_client,), ) with right: diff --git a/snowflake_cybersyn_demo/local_launcher.py b/snowflake_cybersyn_demo/local_launcher.py new file mode 100644 index 0000000..46a02a4 --- /dev/null +++ b/snowflake_cybersyn_demo/local_launcher.py @@ -0,0 +1,18 @@ +from human_in_the_loop.additional_services.human_in_the_loop import ( + human_service, +) +from human_in_the_loop.agent_services.funny_agent import agent_server +from human_in_the_loop.core_services.control_plane import control_plane +from human_in_the_loop.core_services.message_queue import message_queue +from llama_agents import ServerLauncher + +# launch it +launcher = ServerLauncher( + [agent_server, human_service], + control_plane, + message_queue, +) + + +if __name__ == "__main__": + launcher.launch_servers()