local launcher and app controller
nerdai committed Jul 25, 2024
1 parent 6110a7f commit 44885c1
Showing 3 changed files with 96 additions and 31 deletions.
88 changes: 67 additions & 21 deletions snowflake_cybersyn_demo/apps/controller.py
@@ -1,11 +1,20 @@
import asyncio
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Generator, List, Optional

import streamlit as st
from llama_agents import LlamaAgentsClient
from llama_agents import (
CallableMessageConsumer,
LlamaAgentsClient,
QueueMessage,
)
from llama_agents.types import ActionTypes, TaskResult
from llama_index.core.llms import ChatMessage, ChatResponseGen

logger = logging.getLogger(__name__)


class TaskStatus(str, Enum):
HUMAN_REQUIRED = "human_required"
@@ -23,8 +32,47 @@ class TaskModel:


class Controller:
def __init__(self, llama_agents_client: Optional[LlamaAgentsClient]):
self._client = llama_agents_client
def __init__(
self,
human_in_loop_queue: asyncio.Queue,
human_in_loop_result_queue: asyncio.Queue,
control_plane_host: str = "127.0.0.1",
control_plane_port: Optional[int] = 8000,
):
self.human_in_loop_queue = human_in_loop_queue
self.human_in_loop_result_queue = human_in_loop_result_queue
self._client = LlamaAgentsClient(
control_plane_url=(
f"http://{control_plane_host}:{control_plane_port}"
if control_plane_port
else f"http://{control_plane_host}"
)
)
self._step_interval = 0.5
self._timeout = 60
self._raise_timeout = False
self._human_in_the_loop_task: Optional[str] = None
self._human_input: Optional[str] = None
self._final_task_consumer = CallableMessageConsumer(
message_type="human", handler=self._process_completed_task_messages
)
self._completed_tasks_queue: asyncio.Queue[
TaskResult
] = asyncio.Queue()

async def _process_completed_task_messages(
self, message: QueueMessage, **kwargs: Any
) -> None:
"""Consumer of completed tasks.
By default control plane sends to message consumer of type "human".
The process message logic contained here simply puts the TaskResult into
a queue that is continuosly via a gr.Timer().
"""
if message.action == ActionTypes.COMPLETED_TASK:
task_res = TaskResult(**message.data)
await self._completed_tasks_queue.put(task_res)
logger.info("Added task result to queue")

def _llama_index_stream_wrapper(
self,
@@ -33,27 +81,25 @@ def _llama_index_stream_wrapper(
for chunk in llama_index_stream:
yield chunk.delta

def _handle_task_submission(
self, llama_agents_client: LlamaAgentsClient
) -> None:
def _handle_task_submission(self) -> None:
"""Handle the user submitted message. Clear task submission box, and
add the new task to the submitted list.
"""

# create new task and store in state
# task_input = st.session_state.task_input
# task_id = self._client.create_task(task_input)
# task = TaskModel(
# task_id=task_id,
# input=task_input,
# chat_history=[
# ChatMessage(role="user", content=task_input),
# ChatMessage(
# role="assistant",
# content=f"Successfully submitted task: {task_id}.",
# ),
# ],
# status=TaskStatus.SUBMITTED,
# )
task_input = st.session_state.task_input
task_id = self._client.create_task(task_input)
task = TaskModel(
task_id=task_id,
input=task_input,
chat_history=[
ChatMessage(role="user", content=task_input),
ChatMessage(
role="assistant",
content=f"Successfully submitted task: {task_id}.",
),
],
status=TaskStatus.SUBMITTED,
)
st.session_state.submitted_pills.append(st.session_state.task_input)
# st.session_state.tasks.append(task)
st.session_state.tasks.append(task)
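
The _final_task_consumer defined above only constructs the consumer; the app still has to register it with the message queue and start it consuming before completed tasks can land in _completed_tasks_queue. A minimal sketch of that wiring — assuming llama-agents' in-process SimpleMessageQueue, and assuming its async register_consumer() returns a start-consuming callable (neither is shown in this diff):

import asyncio

from llama_agents import CallableMessageConsumer, QueueMessage, SimpleMessageQueue
from llama_agents.types import ActionTypes, TaskResult

completed_tasks: asyncio.Queue = asyncio.Queue()


async def process_completed(message: QueueMessage, **kwargs) -> None:
    # Same shape as Controller._process_completed_task_messages: keep only
    # completed-task messages and park the TaskResult for the UI to drain.
    if message.action == ActionTypes.COMPLETED_TASK:
        await completed_tasks.put(TaskResult(**message.data))


async def main() -> None:
    mq = SimpleMessageQueue()
    consumer = CallableMessageConsumer(
        message_type="human", handler=process_completed
    )
    # Assumption: register_consumer() is async and returns a callable that
    # starts the consumer loop, as in llama-agents' SimpleMessageQueue.
    start_consuming = await mq.register_consumer(consumer)
    consume_task = asyncio.create_task(start_consuming())
    await asyncio.sleep(1)  # in a real app, this runs alongside the UI loop
    consume_task.cancel()


asyncio.run(main())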
21 changes: 11 additions & 10 deletions snowflake_cybersyn_demo/apps/streamlit.py
@@ -1,3 +1,6 @@
import asyncio
from typing import Dict

import pandas as pd
import streamlit as st
from llama_index.core.llms import ChatMessage
@@ -8,15 +11,14 @@
llm = OpenAI(model="gpt-4o-mini")
control_plane_host = "0.0.0.0"
control_plane_port = 8001
# llama_agents_client = LlamaAgentsClient(
# control_plane_url=(
# f"http://{control_plane_host}:{control_plane_port}"
# if control_plane_port
# else f"http://{control_plane_host}"
# )
# )
llama_agents_client = None
controller = Controller(llama_agents_client)
human_input_request_queue: asyncio.Queue[Dict[str, str]] = asyncio.Queue()
human_input_result_queue: asyncio.Queue[str] = asyncio.Queue()
controller = Controller(
human_in_loop_queue=human_input_request_queue,
human_in_loop_result_queue=human_input_result_queue,
control_plane_host=control_plane_host,
control_plane_port=control_plane_port,
)

### App
st.set_page_config(layout="wide")
@@ -39,7 +41,6 @@
placeholder="Enter a task input.",
key="task_input",
on_change=controller._handle_task_submission,
args=(llama_agents_client,),
)

with right:
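
The controller is now constructed around two asyncio queues: the human service is expected to put input requests on human_input_request_queue, and the UI to push the human's answer onto human_input_result_queue. That handshake isn't shown in this diff; a hypothetical sketch of the shape it implies (the function name and the dict payload are illustrative assumptions, not the repo's code):

import asyncio
from typing import Dict


async def serve_human_input(
    requests: asyncio.Queue,  # Dict[str, str] prompts from the agents
    results: asyncio.Queue,   # str answers back to the agents
) -> None:
    while True:
        request: Dict[str, str] = await requests.get()
        # Stand-in for a UI widget; the Streamlit app would render the
        # prompt and collect the answer asynchronously instead.
        answer = input(request.get("prompt", "Input needed: "))
        await results.put(answer)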
18 changes: 18 additions & 0 deletions snowflake_cybersyn_demo/local_launcher.py
@@ -0,0 +1,18 @@
from human_in_the_loop.additional_services.human_in_the_loop import (
human_service,
)
from human_in_the_loop.agent_services.funny_agent import agent_server
from human_in_the_loop.core_services.control_plane import control_plane
from human_in_the_loop.core_services.message_queue import message_queue
from llama_agents import ServerLauncher

# launch it
launcher = ServerLauncher(
[agent_server, human_service],
control_plane,
message_queue,
)


if __name__ == "__main__":
launcher.launch_servers()
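
launch_servers() blocks and runs the message queue, control plane, agent service, and human service together. Once they are up, tasks can be submitted from another process with the same client API the controller uses; a minimal sketch (the URL is an assumption — match it to your control plane's host and port):

from llama_agents import LlamaAgentsClient

client = LlamaAgentsClient(control_plane_url="http://127.0.0.1:8000")
task_id = client.create_task("How hot will it be in San Francisco tomorrow?")
print(f"Submitted task: {task_id}")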
