diff --git a/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py b/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py index 0059015..0148d39 100644 --- a/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py +++ b/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py @@ -37,7 +37,15 @@ async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str: # poll until human answer is stored async def _poll_for_human_input_result() -> str: - return human_input_result_queue.get() + human_input = None + while human_input is None: + try: + human_input = human_input_result_queue.get_nowait() + except queue.Empty: + human_input = None + await asyncio.sleep(0.1) + logger.info("human input recieved") + return human_input try: human_input = await asyncio.wait_for( diff --git a/snowflake_cybersyn_demo/apps/controller.py b/snowflake_cybersyn_demo/apps/controller.py index 1407032..a6ad2d4 100644 --- a/snowflake_cybersyn_demo/apps/controller.py +++ b/snowflake_cybersyn_demo/apps/controller.py @@ -1,4 +1,5 @@ import logging +import queue import time from dataclasses import dataclass, field from enum import Enum @@ -78,6 +79,18 @@ def handle_task_submission(self) -> None: st.session_state.current_task = task st.session_state.task_input = "" + def get_human_input_handler( + self, human_input_result_queue: queue.Queue + ) -> Callable: + def human_input_handler() -> None: + human_input = st.session_state.human_input + if human_input == "": + return + human_input_result_queue.put_nowait(human_input) + logger.info("pushed human input to human input result queue.") + + return human_input_handler + def update_associated_task_to_completed_status( self, task_res: TaskResult, @@ -87,24 +100,40 @@ def update_associated_task_to_completed_status( Update session_state lists as well. """ - try: - task_list = st.session_state.get("submitted_tasks") - print(f"submitted tasks: {task_list}") - ix, task = next( - (ix, t) - for ix, t in enumerate(task_list) - if t.task_id == task_res.task_id - ) - task.status = TaskStatus.COMPLETED - task.history.append( - ChatMessage(role="assistant", content=task_res.result) + + def remove_task_from_list( + task_list: List[TaskModel], + ) -> List[TaskModel]: + try: + ix, task = next( + (ix, t) + for ix, t in enumerate(task_list) + if t.task_id == task_res.task_id + ) + task.status = TaskStatus.COMPLETED + task.history.append( + ChatMessage(role="assistant", content=task_res.result) + ) + del task_list[ix] + st.session_state.completed_tasks.append(task) + logger.info("updated submitted and completed tasks list.") + except StopIteration: + raise ValueError("Cannot find task in list of tasks.") + return task_list + + submitted_tasks = st.session_state.get("submitted_tasks") + human_required_tasks = st.session_state.get("human_required_tasks") + + if task_res.task_id in [t.task_id for t in submitted_tasks]: + updated_task_list = remove_task_from_list(submitted_tasks) + st.session_state.submitted_tasks = updated_task_list + elif task_res.task_id in [t.task_id for t in human_required_tasks]: + updated_task_list = remove_task_from_list(human_required_tasks) + st.session_state.human_required_tasks = updated_task_list + else: + raise ValueError( + "Completed task not in submitted or human_required lists." ) - del task_list[ix] - st.session_state.submitted_tasks = task_list - st.session_state.completed_tasks.append(task) - logger.info("updated submitted and completed tasks list.") - except StopIteration: - raise ValueError("Cannot find task in list of tasks.") def update_associated_task_to_human_required_status( self, diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py index ae02f21..11caab3 100644 --- a/snowflake_cybersyn_demo/apps/streamlit.py +++ b/snowflake_cybersyn_demo/apps/streamlit.py @@ -140,6 +140,8 @@ async def start_consuming_finalized_tasks( st.session_state.messages = [] if "current_task" not in st.session_state: st.session_state.current_task = None +if "human_input" not in st.session_state: + st.session_state.human_input = "" left, right = st.columns([1, 2], vertical_alignment="top") @@ -154,39 +156,46 @@ async def start_consuming_finalized_tasks( def chat_window() -> None: - with st.sidebar: - - @st.experimental_fragment(run_every="5s") - def show_chat_window() -> None: - messages_container = st.container(height=500) - with messages_container: - if st.session_state.current_task: - messages = [ - m.dict() for m in st.session_state.current_task.history - ] - for message in messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) - else: - st.empty() - - show_chat_window() - - if _ := st.chat_input("What is up?"): - pass - # st.session_state.messages.append({"role": "user", "content": prompt}) - # with st.chat_message("user"): - # st.markdown(prompt) - - # with st.chat_message("assistant"): - # stream = llm.stream_chat( - # messages=[ - # ChatMessage(role=m["role"], content=m["content"]) - # for m in st.session_state.messages - # ] - # ) - # response = st.write_stream(controller.llama_index_stream_wrapper(stream)) - # st.session_state.messages.append({"role": "assistant", "content": response}) + pass + # with st.sidebar: + + # @st.experimental_fragment(run_every="5s") + # def show_chat_window() -> None: + # messages_container = st.container(height=500) + # with messages_container: + # if st.session_state.current_task: + # messages = [m.dict() for m in st.session_state.current_task.history] + # for message in messages: + # with st.chat_message(message["role"]): + # st.markdown(message["content"]) + # else: + # st.empty() + + # show_chat_window() + + # if human_input := st.chat_input("Provide human input."): + # if st.session_state.current_task: + # st.session_state.current_task.history.append( + # ChatMessage(role="user", content=human_input) + # ) + # human_input_result_queue.put(human_input) + # time.sleep(1) + # logger.info("pushed human input to human input result queue.") + + # logger.info(f"HUMAN INPUT: {human_input}") + # st.session_state.messages.append({"role": "user", "content": prompt}) + # with st.chat_message("user"): + # st.markdown(prompt) + + # with st.chat_message("assistant"): + # stream = llm.stream_chat( + # messages=[ + # ChatMessage(role=m["role"], content=m["content"]) + # for m in st.session_state.messages + # ] + # ) + # response = st.write_stream(controller.llama_index_stream_wrapper(stream)) + # st.session_state.messages.append({"role": "assistant", "content": response}) @st.experimental_fragment(run_every="30s") @@ -210,10 +219,16 @@ def task_df() -> None: + ["human_required"] * len(st.session_state.human_required_tasks) + ["completed"] * len(st.session_state.completed_tasks) ) - data = {"task_id": task_ids, "input": tasks, "status": status} + + data = { + "task_id": task_ids, + "input": tasks, + "status": status, + } + logger.info(f"data: {data}") df = pd.DataFrame(data) - st.dataframe( + event = st.dataframe( df, hide_index=True, selection_mode="single-row", @@ -222,6 +237,22 @@ def task_df() -> None: key="task_df", ) + popover_enabled = ( + len(event.selection["rows"]) > 0 + and st.session_state.current_task.status == "human_required" + ) + with st.popover("Human Input", disabled=not popover_enabled): + if popover_enabled: + human_prompt = st.session_state.current_task.history[-1].content + st.markdown(human_prompt) + st.text_input( + "Provide human input", + key="human_input", + on_change=controller.get_human_input_handler( + human_input_result_queue + ), + ) + task_df()