
outsource more logic to controller
nerdai committed Jul 30, 2024
1 parent 5a360c1 commit c162c73
Showing 4 changed files with 103 additions and 93 deletions.
6 changes: 3 additions & 3 deletions snowflake_cybersyn_demo/additional_services/human_in_the_loop.py
@@ -1,7 +1,7 @@
import asyncio
import logging
import queue
from typing import Any, Dict, TypedDict
from typing import Any, TypedDict

from llama_agents import HumanService, ServiceComponent
from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
@@ -28,8 +28,8 @@ class HumanRequest(TypedDict):


# # human in the loop function
human_input_request_queue = queue.Queue()
human_input_result_queue = queue.Queue()
human_input_request_queue: queue.Queue[HumanRequest] = queue.Queue()
human_input_result_queue: queue.Queue[str] = queue.Queue()


async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str:
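The hunk above parameterizes the module-level queues, so a type checker can verify what the human-in-the-loop producer puts and what the consumer gets back. A minimal standalone sketch of that handoff, assuming a HumanRequest with the task_id and prompt keys this commit reads elsewhere:

import queue
from typing import TypedDict


class HumanRequest(TypedDict):
    task_id: str
    prompt: str


# Parameterized queues: mypy flags a put()/get() of the wrong type.
human_input_request_queue: queue.Queue[HumanRequest] = queue.Queue()
human_input_result_queue: queue.Queue[str] = queue.Queue()

# The human service enqueues a request for human input...
human_input_request_queue.put({"task_id": "abc-123", "prompt": "Approve this plan?"})

# ...and the UI side pops it and answers with a plain string.
req = human_input_request_queue.get()
human_input_result_queue.put(f"Approved task {req['task_id']}")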
33 changes: 0 additions & 33 deletions snowflake_cybersyn_demo/apps/async_list.py

This file was deleted.

65 changes: 63 additions & 2 deletions snowflake_cybersyn_demo/apps/controller.py
@@ -5,8 +5,13 @@

import streamlit as st
from llama_agents import LlamaAgentsClient
from llama_agents.types import TaskResult
from llama_index.core.llms import ChatMessage, ChatResponseGen

from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
HumanRequest,
)

logger = logging.getLogger(__name__)


@@ -41,14 +46,14 @@ def __init__(
self._step_interval = 0.5
self._timeout = 60

def _llama_index_stream_wrapper(
def llama_index_stream_wrapper(
self,
llama_index_stream: ChatResponseGen,
) -> Generator[str, Any, Any]:
for chunk in llama_index_stream:
yield chunk.delta

def _handle_task_submission(self) -> None:
def handle_task_submission(self) -> None:
"""Handle the user submitted message. Clear task submission box, and
add the new task to the submitted list.
"""
@@ -73,3 +78,59 @@ def _handle_task_submission(self) -> None:
st.session_state.submitted_tasks.append(task)
logger.info("Added task to submitted queue")
st.session_state.task_input = ""

def update_associated_task_to_completed_status(
self,
task_res: TaskResult,
) -> None:
"""
Update the task's status to completed for the received task result.
Update the session_state lists as well.
"""
try:
task_list = st.session_state.get("submitted_tasks")
print(f"submitted tasks: {task_list}")
ix, task = next(
(ix, t)
for ix, t in enumerate(task_list)
if t.task_id == task_res.task_id
)
task.status = TaskStatus.COMPLETED
task.chat_history.append(
ChatMessage(role="assistant", content=task_res.result)
)
del task_list[ix]
st.session_state.submitted_tasks = task_list
st.session_state.completed_tasks.append(task)
logger.info("updated submitted and completed tasks list.")
except StopIteration:
raise ValueError("Cannot find task in list of tasks.")

def update_associated_task_to_human_required_status(
self,
human_req: HumanRequest,
) -> None:
"""
Update the task's status to human_required for the received human input request.
Update the session_state lists as well.
"""
try:
task_list = st.session_state.get("submitted_tasks")
print(f"submitted tasks: {task_list}")
ix, task = next(
(ix, t)
for ix, t in enumerate(task_list)
if t.task_id == human_req["task_id"]
)
task.status = TaskStatus.HUMAN_REQUIRED
task.chat_history.append(
ChatMessage(role="assistant", content=human_req["prompt"])
)
del task_list[ix]
st.session_state.submitted_tasks = task_list
st.session_state.human_required_tasks.append(task)
logger.info("updated submitted and human required tasks list.")
except StopIteration:
raise ValueError("Cannot find task in list of tasks.")
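The two methods above pull the session-state bookkeeping out of the Streamlit script; the fragments in streamlit.py below become thin callers. A rough sketch of that delegation, with the polling wiring assumed for illustration:

import queue

def drain_status_updates(
    controller: Controller,
    completed_queue: queue.Queue,
    human_queue: queue.Queue,
) -> None:
    # Non-blocking polls: an empty queue is the normal case, not an error.
    try:
        controller.update_associated_task_to_completed_status(
            task_res=completed_queue.get_nowait()
        )
    except queue.Empty:
        pass
    try:
        controller.update_associated_task_to_human_required_status(
            human_req=human_queue.get_nowait()
        )
    except queue.Empty:
        pass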
92 changes: 37 additions & 55 deletions snowflake_cybersyn_demo/apps/streamlit.py
@@ -3,21 +3,19 @@
import queue
import threading
import time
from typing import Optional
from typing import Optional, Tuple

import pandas as pd
import streamlit as st
from llama_agents.types import TaskResult
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI

from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
HumanRequest,
HumanService,
)
from snowflake_cybersyn_demo.apps.controller import (
Controller,
TaskStatus,
)
from snowflake_cybersyn_demo.apps.controller import Controller
from snowflake_cybersyn_demo.apps.final_task_consumer import FinalTaskConsumer

logger = logging.getLogger(__name__)
@@ -31,28 +29,35 @@


@st.cache_resource
def startup():
def startup() -> (
Tuple[
Controller,
queue.Queue[TaskResult],
FinalTaskConsumer,
queue.Queue[HumanRequest],
queue.Queue[str],
]
):
from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
human_input_request_queue,
human_input_result_queue,
human_service,
message_queue,
)

completed_tasks_queue = queue.Queue()
controller = Controller(
control_plane_host=control_plane_host,
control_plane_port=control_plane_port,
)

async def start_consuming_human_tasks(human_service: HumanService):
human_task_consuming_callable = await message_queue.register_consumer(
human_service.as_consumer()
async def start_consuming_human_tasks(hs: HumanService) -> None:
consuming_callable = await message_queue.register_consumer(
hs.as_consumer()
)

ht_task = asyncio.create_task(human_task_consuming_callable())
ht_task = asyncio.create_task(consuming_callable()) # noqa: F841

launch_task = asyncio.create_task(human_service.processing_loop())
pl_task = asyncio.create_task(hs.processing_loop()) # noqa: F841

await asyncio.Future()
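start_consuming_human_tasks parks on await asyncio.Future(), a future that never resolves, so the consumer and processing-loop tasks it spawned keep running for the lifetime of the event loop. A minimal sketch of the thread wiring that the collapsed context below only partially shows (the asyncio.run/lambda choice is an assumption):

import asyncio
import threading

async def run_consumers_forever() -> None:
    # ... register consumers and spawn their tasks here ...
    await asyncio.Future()  # never resolves; keeps the loop alive

# Assumed wiring: run the coroutine on its own event loop in a daemon
# thread so the Streamlit script thread is never blocked.
hr_thread = threading.Thread(
    target=lambda: asyncio.run(run_consumers_forever()),
    daemon=True,
)
hr_thread.start()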

@@ -64,12 +69,15 @@ async def start_consuming_human_tasks(human_service: HumanService):
)
hr_thread.start()

completed_tasks_queue: queue.Queue[TaskResult] = queue.Queue()
final_task_consumer = FinalTaskConsumer(
message_queue=message_queue,
completed_tasks_queue=completed_tasks_queue,
)

async def start_consuming_finalized_tasks(final_task_consumer):
async def start_consuming_finalized_tasks(
final_task_consumer: FinalTaskConsumer,
) -> None:
final_task_consuming_callable = (
await final_task_consumer.register_to_message_queue()
)
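Because startup() is decorated with @st.cache_resource, the controller, queues, threads, and consumers are created once per Streamlit server process and reused across script reruns. A sketch of the call site, assuming the five values come back in the order of the Tuple annotation above:

(
    controller,
    completed_tasks_queue,
    final_task_consumer,
    human_input_request_queue,
    human_input_result_queue,
) = startup()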
@@ -129,7 +137,7 @@ async def start_consuming_finalized_tasks(final_task_consumer):
"Task input",
placeholder="Enter a task input.",
key="task_input",
on_change=controller._handle_task_submission,
on_change=controller.handle_task_submission,
)

with right:
@@ -152,12 +160,16 @@ async def start_consuming_finalized_tasks(final_task_consumer):
for m in st.session_state.messages
]
)
response = st.write_stream(controller._llama_index_stream_wrapper(stream))
st.session_state.messages.append({"role": "assistant", "content": response})
response = st.write_stream(
controller.llama_index_stream_wrapper(stream)
)
st.session_state.messages.append(
{"role": "assistant", "content": response}
)


@st.experimental_fragment(run_every="30s")
def task_df():
def task_df() -> None:
st.text("Task Status")
st.button("Refresh")
tasks = (
@@ -183,7 +195,7 @@


@st.experimental_fragment(run_every=5)
def process_completed_tasks(completed_queue: queue.Queue):
def process_completed_tasks(completed_queue: queue.Queue) -> None:
task_res: Optional[TaskResult] = None
try:
task_res = completed_queue.get_nowait()
@@ -192,24 +204,9 @@ def process_completed_tasks(completed_queue: queue.Queue):
logger.info("task result queue is empty.")

if task_res:
try:
task_list = st.session_state.get("submitted_tasks")
print(f"submitted tasks: {task_list}")
ix, task = next(
(ix, t)
for ix, t in enumerate(task_list)
if t.task_id == task_res.task_id
)
task.status = TaskStatus.COMPLETED
task.chat_history.append(
ChatMessage(role="assistant", content=task_res.result)
)
del task_list[ix]
st.session_state.submitted_tasks = task_list
st.session_state.completed_tasks.append(task)
logger.info("updated submitted and completed tasks list.")
except StopIteration:
raise ValueError("Cannot find task in list of tasks.")
controller.update_associated_task_to_completed_status(
task_res=task_res
)


process_completed_tasks(completed_queue=completed_tasks_queue)
@@ -218,7 +215,7 @@ def process_completed_tasks(completed_queue: queue.Queue):
@st.experimental_fragment(run_every=5)
def process_human_input_requests(
human_requests_queue: queue.Queue[HumanRequest],
):
) -> None:
human_req: Optional[HumanRequest] = None
try:
human_req = human_requests_queue.get_nowait()
@@ -227,24 +224,9 @@ def process_human_input_requests(
logger.info("human request queue is empty.")

if human_req:
try:
task_list = st.session_state.get("submitted_tasks")
print(f"submitted tasks: {task_list}")
ix, task = next(
(ix, t)
for ix, t in enumerate(task_list)
if t.task_id == human_req["task_id"]
)
task.status = TaskStatus.COMPLETED
task.chat_history.append(
ChatMessage(role="assistant", content=human_req["prompt"])
)
del task_list[ix]
st.session_state.submitted_tasks = task_list
st.session_state.human_required_tasks.append(task)
logger.info("updated submitted and human required tasks list.")
except StopIteration:
raise ValueError("Cannot find task in list of tasks.")
controller.update_associated_task_to_human_required_status(
human_req=human_req
)


process_human_input_requests(human_requests_queue=human_input_request_queue)
