diff --git a/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py b/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py
index bd87b4f..b26af32 100644
--- a/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py
+++ b/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 import queue
-from typing import Any, Dict, TypedDict
+from typing import Any, TypedDict
 
 from llama_agents import HumanService, ServiceComponent
 from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
@@ -28,8 +28,8 @@ class HumanRequest(TypedDict):
 #
 #
 # human in the loop function
-human_input_request_queue = queue.Queue()
-human_input_result_queue = queue.Queue()
+human_input_request_queue: queue.Queue[HumanRequest] = queue.Queue()
+human_input_result_queue: queue.Queue[str] = queue.Queue()
 
 
 async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str:
diff --git a/snowflake_cybersyn_demo/apps/async_list.py b/snowflake_cybersyn_demo/apps/async_list.py
deleted file mode 100644
index 95ab012..0000000
--- a/snowflake_cybersyn_demo/apps/async_list.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from asyncio import Lock
-
-
-class AsyncSafeList:
-    def __init__(self):
-        self._list = list()
-        self._lock = Lock()
-
-    def __aiter__(self):
-        return aiter(self._list)
-
-    def __iter__(self):
-        return iter(self._list)
-
-    async def append(self, value):
-        async with self._lock:
-            self._list.append(value)
-
-    async def pop(self):
-        async with self._lock:
-            return self._list.pop()
-
-    async def delete(self, index):
-        async with self._lock:
-            del self._list[index]
-
-    async def get(self, index):
-        async with self._lock:
-            return self._list[index]
-
-    async def length(self):
-        async with self._lock:
-            return len(self._list)
diff --git a/snowflake_cybersyn_demo/apps/controller.py b/snowflake_cybersyn_demo/apps/controller.py
index 3c4ae55..b43e3bf 100644
--- a/snowflake_cybersyn_demo/apps/controller.py
+++ b/snowflake_cybersyn_demo/apps/controller.py
@@ -5,8 +5,13 @@
 
 import streamlit as st
 from llama_agents import LlamaAgentsClient
+from llama_agents.types import TaskResult
 from llama_index.core.llms import ChatMessage, ChatResponseGen
 
+from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
+    HumanRequest,
+)
+
 logger = logging.getLogger(__name__)
 
 
@@ -41,14 +46,14 @@ def __init__(
         self._step_interval = 0.5
         self._timeout = 60
 
-    def _llama_index_stream_wrapper(
+    def llama_index_stream_wrapper(
         self,
         llama_index_stream: ChatResponseGen,
     ) -> Generator[str, Any, Any]:
         for chunk in llama_index_stream:
             yield chunk.delta
 
-    def _handle_task_submission(self) -> None:
+    def handle_task_submission(self) -> None:
         """Handle the user submitted message. Clear task submission box,
         and add the new task to the submitted list.
         """
@@ -73,3 +78,59 @@ def _handle_task_submission(self) -> None:
         st.session_state.submitted_tasks.append(task)
         logger.info("Added task to submitted queue")
         st.session_state.task_input = ""
+
+    def update_associated_task_to_completed_status(
+        self,
+        task_res: TaskResult,
+    ) -> None:
+        """
+        Update task status to completed for received task result.
+
+        Update session_state lists as well.
+        """
+        try:
+            task_list = st.session_state.get("submitted_tasks")
+            print(f"submitted tasks: {task_list}")
+            ix, task = next(
+                (ix, t)
+                for ix, t in enumerate(task_list)
+                if t.task_id == task_res.task_id
+            )
+            task.status = TaskStatus.COMPLETED
+            task.chat_history.append(
+                ChatMessage(role="assistant", content=task_res.result)
+            )
+            del task_list[ix]
+            st.session_state.submitted_tasks = task_list
+            st.session_state.completed_tasks.append(task)
+            logger.info("updated submitted and completed tasks list.")
+        except StopIteration:
+            raise ValueError("Cannot find task in list of tasks.")
+
+    def update_associated_task_to_human_required_status(
+        self,
+        human_req: HumanRequest,
+    ) -> None:
+        """
+        Update task status to human_required for received task request.
+
+        Update session_state lists as well.
+        """
+        try:
+            task_list = st.session_state.get("submitted_tasks")
+            print(f"submitted tasks: {task_list}")
+            ix, task = next(
+                (ix, t)
+                for ix, t in enumerate(task_list)
+                if t.task_id == human_req["task_id"]
+            )
+            task.status = TaskStatus.HUMAN_REQUIRED
+            task.chat_history.append(
+                ChatMessage(role="assistant", content=human_req["prompt"])
+            )
+            del task_list[ix]
+            st.session_state.submitted_tasks = task_list
+            st.session_state.human_required_tasks.append(task)
+            logger.info("updated submitted and human required tasks list.")
+        except StopIteration:
+            raise ValueError("Cannot find task in list of tasks.")
diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py
index efe6c38..64cdce3 100644
--- a/snowflake_cybersyn_demo/apps/streamlit.py
+++ b/snowflake_cybersyn_demo/apps/streamlit.py
@@ -3,10 +3,11 @@
 import queue
 import threading
 import time
-from typing import Optional
+from typing import Optional, Tuple
 
 import pandas as pd
 import streamlit as st
+from llama_agents.types import TaskResult
 from llama_index.core.llms import ChatMessage
 from llama_index.llms.openai import OpenAI
 
@@ -14,10 +15,7 @@
     HumanRequest,
     HumanService,
 )
-from snowflake_cybersyn_demo.apps.controller import (
-    Controller,
-    TaskStatus,
-)
+from snowflake_cybersyn_demo.apps.controller import Controller
 from snowflake_cybersyn_demo.apps.final_task_consumer import FinalTaskConsumer
 
 logger = logging.getLogger(__name__)
@@ -31,7 +29,15 @@
 
 
 @st.cache_resource
-def startup():
+def startup() -> (
+    Tuple[
+        Controller,
+        queue.Queue[TaskResult],
+        FinalTaskConsumer,
+        queue.Queue[HumanRequest],
+        queue.Queue[str],
+    ]
+):
     from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
         human_input_request_queue,
         human_input_result_queue,
@@ -39,20 +45,19 @@ def startup():
         message_queue,
     )
 
-    completed_tasks_queue = queue.Queue()
     controller = Controller(
         control_plane_host=control_plane_host,
         control_plane_port=control_plane_port,
     )
 
-    async def start_consuming_human_tasks(human_service: HumanService):
-        human_task_consuming_callable = await message_queue.register_consumer(
-            human_service.as_consumer()
+    async def start_consuming_human_tasks(hs: HumanService) -> None:
+        consuming_callable = await message_queue.register_consumer(
+            hs.as_consumer()
         )
 
-        ht_task = asyncio.create_task(human_task_consuming_callable())
+        ht_task = asyncio.create_task(consuming_callable())  # noqa: F841
 
-        launch_task = asyncio.create_task(human_service.processing_loop())
+        pl_task = asyncio.create_task(hs.processing_loop())  # noqa: F841
 
         await asyncio.Future()
 
@@ -64,12 +69,15 @@ async def start_consuming_human_tasks(human_service: HumanService):
     )
     hr_thread.start()
 
+    completed_tasks_queue: queue.Queue[TaskResult] = queue.Queue()
     final_task_consumer = FinalTaskConsumer(
         message_queue=message_queue,
         completed_tasks_queue=completed_tasks_queue,
     )
 
-    async def start_consuming_finalized_tasks(final_task_consumer):
+    async def start_consuming_finalized_tasks(
+        final_task_consumer: FinalTaskConsumer,
+    ) -> None:
         final_task_consuming_callable = (
             await final_task_consumer.register_to_message_queue()
         )
@@ -129,7 +137,7 @@ async def start_consuming_finalized_tasks(final_task_consumer):
         "Task input",
         placeholder="Enter a task input.",
         key="task_input",
-        on_change=controller._handle_task_submission,
+        on_change=controller.handle_task_submission,
     )
 
 with right:
@@ -152,12 +160,16 @@ async def start_consuming_finalized_tasks(final_task_consumer):
                 for m in st.session_state.messages
             ]
         )
-        response = st.write_stream(controller._llama_index_stream_wrapper(stream))
-        st.session_state.messages.append({"role": "assistant", "content": response})
+        response = st.write_stream(
+            controller.llama_index_stream_wrapper(stream)
+        )
+        st.session_state.messages.append(
+            {"role": "assistant", "content": response}
+        )
 
 
 @st.experimental_fragment(run_every="30s")
-def task_df():
+def task_df() -> None:
     st.text("Task Status")
     st.button("Refresh")
     tasks = (
@@ -183,7 +195,7 @@
 
 
 @st.experimental_fragment(run_every=5)
-def process_completed_tasks(completed_queue: queue.Queue):
+def process_completed_tasks(completed_queue: queue.Queue) -> None:
     task_res: Optional[TaskResult] = None
     try:
         task_res = completed_queue.get_nowait()
@@ -192,24 +204,9 @@
         logger.info("task result queue is empty.")
 
     if task_res:
-        try:
-            task_list = st.session_state.get("submitted_tasks")
-            print(f"submitted tasks: {task_list}")
-            ix, task = next(
-                (ix, t)
-                for ix, t in enumerate(task_list)
-                if t.task_id == task_res.task_id
-            )
-            task.status = TaskStatus.COMPLETED
-            task.chat_history.append(
-                ChatMessage(role="assistant", content=task_res.result)
-            )
-            del task_list[ix]
-            st.session_state.submitted_tasks = task_list
-            st.session_state.completed_tasks.append(task)
-            logger.info("updated submitted and completed tasks list.")
-        except StopIteration:
-            raise ValueError("Cannot find task in list of tasks.")
+        controller.update_associated_task_to_completed_status(
+            task_res=task_res
+        )
 
 
 process_completed_tasks(completed_queue=completed_tasks_queue)
@@ -218,7 +215,7 @@
 @st.experimental_fragment(run_every=5)
 def process_human_input_requests(
     human_requests_queue: queue.Queue[HumanRequest],
-):
+) -> None:
     human_req: Optional[HumanRequest] = None
     try:
         human_req = human_requests_queue.get_nowait()
@@ -227,24 +224,9 @@
         logger.info("human request queue is empty.")
 
     if human_req:
-        try:
-            task_list = st.session_state.get("submitted_tasks")
-            print(f"submitted tasks: {task_list}")
-            ix, task = next(
-                (ix, t)
-                for ix, t in enumerate(task_list)
-                if t.task_id == human_req["task_id"]
-            )
-            task.status = TaskStatus.COMPLETED
-            task.chat_history.append(
-                ChatMessage(role="assistant", content=human_req["prompt"])
-            )
-            del task_list[ix]
-            st.session_state.submitted_tasks = task_list
-            st.session_state.human_required_tasks.append(task)
-            logger.info("updated submitted and human required tasks list.")
-        except StopIteration:
-            raise ValueError("Cannot find task in list of tasks.")
+        controller.update_associated_task_to_human_required_status(
+            human_req=human_req
+        )
 
 
 process_human_input_requests(human_requests_queue=human_input_request_queue)
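---

Note, not part of the patch: the refactor above moves the session-state
bookkeeping out of the Streamlit fragments and into Controller, so each
polling fragment reduces to a non-blocking queue read plus one controller
call. A minimal sketch of that pattern under the same assumptions as the
diff (the controller object and the TaskResult-typed queue come from
startup(); poll_completed_tasks is an illustrative name, not code from
this change):

    import queue
    from typing import Optional

    from llama_agents.types import TaskResult

    def poll_completed_tasks(
        controller,  # Controller instance returned by startup()
        completed_queue: "queue.Queue[TaskResult]",
    ) -> None:
        # Non-blocking read: the fragment re-runs on a timer, so an empty
        # queue just means nothing finished since the last poll.
        task_res: Optional[TaskResult] = None
        try:
            task_res = completed_queue.get_nowait()
        except queue.Empty:
            return
        # Delegate the list/status bookkeeping to the controller method
        # added in controller.py above.
        controller.update_associated_task_to_completed_status(task_res=task_res)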