From 878ff60da6f6f6d3dff4b383da752076b4b4c5be Mon Sep 17 00:00:00 2001 From: Andrei Fajardo Date: Fri, 26 Jul 2024 02:56:58 -0400 Subject: [PATCH] argh --- snowflake_cybersyn_demo/apps/controller.py | 4 +- snowflake_cybersyn_demo/apps/streamlit.py | 63 +++++++++++----------- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/snowflake_cybersyn_demo/apps/controller.py b/snowflake_cybersyn_demo/apps/controller.py index fbb8145..f3376ce 100644 --- a/snowflake_cybersyn_demo/apps/controller.py +++ b/snowflake_cybersyn_demo/apps/controller.py @@ -63,9 +63,7 @@ def __init__( self._final_task_consumer = CallableMessageConsumer( message_type="human", handler=self._process_completed_task_messages ) - self._completed_tasks_queue: asyncio.Queue[ - TaskResult - ] = asyncio.Queue() + self._completed_tasks_queue: asyncio.Queue[TaskResult] = asyncio.Queue() async def _process_completed_task_messages( self, message: QueueMessage, **kwargs: Any diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py index 3777739..0a2f05d 100644 --- a/snowflake_cybersyn_demo/apps/streamlit.py +++ b/snowflake_cybersyn_demo/apps/streamlit.py @@ -74,12 +74,8 @@ for m in st.session_state.messages ] ) - response = st.write_stream( - controller._llama_index_stream_wrapper(stream) - ) - st.session_state.messages.append( - {"role": "assistant", "content": response} - ) + response = st.write_stream(controller._llama_index_stream_wrapper(stream)) + st.session_state.messages.append({"role": "assistant", "content": response}) bottom = st.container() with bottom: @@ -168,14 +164,10 @@ def remove_from_list_closure( over to the completed list. """ ix, task = next( - (ix, t) - for ix, t in enumerate(task_list) - if t.task_id == task_res.task_id + (ix, t) for ix, t in enumerate(task_list) if t.task_id == task_res.task_id ) task.status = TaskStatus.COMPLETED - task.chat_history.append( - ChatMessage(role="assistant", content=task_res.result) - ) + task.chat_history.append(ChatMessage(role="assistant", content=task_res.result)) del task_list[ix] st.session_state.completed_tasks.append(task) @@ -188,9 +180,7 @@ def remove_from_list_closure( try: task_res: TaskResult = controller._completed_tasks_queue.get_nowait() logger.info("got new completed task result") - if task_res.task_id in [ - t.task_id for t in st.session_state.submitted_tasks - ]: + if task_res.task_id in [t.task_id for t in st.session_state.submitted_tasks]: remove_from_list_closure( st.session_state.submitted_tasks, TaskStatus.SUBMITTED ) @@ -202,9 +192,7 @@ def remove_from_list_closure( TaskStatus.HUMAN_REQUIRED, ) else: - raise ValueError( - "Completed task not in submitted or human_needed lists." - ) + raise ValueError("Completed task not in submitted or human_needed lists.") except asyncio.QueueEmpty: logger.info("completed task queue is empty.") pass @@ -214,23 +202,34 @@ def remove_from_list_closure( continuously_check_for_completed_tasks() -async def launch() -> None: - start_consuming_callable = ( - await controller._human_service.message_queue.register_consumer( - controller._human_service.as_consumer() +@st.cache_resource +def get_consuming_callables() -> None: + async def launch() -> None: + start_consuming_callable = ( + await controller._human_service.message_queue.register_consumer( + controller._human_service.as_consumer() + ) ) - ) - h_task = asyncio.create_task(start_consuming_callable()) # noqa: F841 - final_task_consuming_callable = ( - await controller._human_service.message_queue.register_consumer( - controller._final_task_consumer + final_task_consuming_callable = ( + await controller._human_service.message_queue.register_consumer( + controller._final_task_consumer + ) ) - ) - f_task = asyncio.create_task(final_task_consuming_callable()) # noqa: F841 - await asyncio.Future() + return start_consuming_callable, final_task_consuming_callable + + return asyncio.run(launch()) + + +start_consuming_callable, final_task_consuming_callable = get_consuming_callables() + + +async def listening_to_queue() -> None: + h_task = asyncio.create_task(start_consuming_callable()) # noqa: F841 + f_task = asyncio.create_task(final_task_consuming_callable()) # noqa: F841 + while True: + await asyncio.sleep(0.1) -if __name__ == "__main__": - asyncio.run(launch()) +asyncio.run(listening_to_queue())