diff --git a/snowflake_cybersyn_demo/apps/async_list.py b/snowflake_cybersyn_demo/apps/async_list.py index b93af73..95ab012 100644 --- a/snowflake_cybersyn_demo/apps/async_list.py +++ b/snowflake_cybersyn_demo/apps/async_list.py @@ -1,5 +1,3 @@ -from typing import List - from asyncio import Lock @@ -22,6 +20,10 @@ async def pop(self): async with self._lock: return self._list.pop() + async def delete(self, index): + async with self._lock: + del self._list[index] + async def get(self, index): async with self._lock: return self._list[index] diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py index 857afc4..c3d98f2 100644 --- a/snowflake_cybersyn_demo/apps/streamlit.py +++ b/snowflake_cybersyn_demo/apps/streamlit.py @@ -1,20 +1,19 @@ import asyncio import logging from typing import Dict, List -from contextvars import ContextVar import pandas as pd import streamlit as st from llama_index.core.llms import ChatMessage from llama_index.llms.openai import OpenAI +from snowflake_cybersyn_demo.apps.async_list import AsyncSafeList from snowflake_cybersyn_demo.apps.controller import ( Controller, TaskModel, TaskResult, TaskStatus, ) -from snowflake_cybersyn_demo.apps.async_list import AsyncSafeList logger = logging.getLogger(__name__) @@ -76,8 +75,12 @@ for m in st.session_state.messages ] ) - response = st.write_stream(controller._llama_index_stream_wrapper(stream)) - st.session_state.messages.append({"role": "assistant", "content": response}) + response = st.write_stream( + controller._llama_index_stream_wrapper(stream) + ) + st.session_state.messages.append( + {"role": "assistant", "content": response} + ) @st.cache_resource @@ -100,7 +103,10 @@ async def launch() -> None: return asyncio.run(launch()) -start_consuming_callable, final_task_consuming_callable = get_consuming_callables() +( + start_consuming_callable, + final_task_consuming_callable, +) = get_consuming_callables() def remove_from_list_closure( @@ -115,10 +121,14 @@ def remove_from_list_closure( over to the completed list. """ ix, task = next( - (ix, t) for ix, t in enumerate(task_list) if t.task_id == task_res.task_id + (ix, t) + for ix, t in enumerate(task_list) + if t.task_id == task_res.task_id ) task.status = TaskStatus.COMPLETED - task.chat_history.append(ChatMessage(role="assistant", content=task_res.result)) + task.chat_history.append( + ChatMessage(role="assistant", content=task_res.result) + ) del task_list[ix] # if current_task: @@ -134,10 +144,11 @@ def remove_from_list_closure( @st.cache_resource def get_async_safe_lists(): submitted_tasks = AsyncSafeList() - return submitted_tasks + completed_tasks = AsyncSafeList() + return submitted_tasks, completed_tasks -submitted_tasks = get_async_safe_lists() +submitted_tasks, completed_tasks = get_async_safe_lists() async def listening_to_queue(ctr) -> None: @@ -146,45 +157,48 @@ async def listening_to_queue(ctr) -> None: f_task = asyncio.create_task(final_task_consuming_callable()) # noqa: F841 human_required_tasks = [] - completed_tasks = [] while True: logger.info(f"submitted: {submitted_tasks._list}") - # logger.info(f"completed: {completed_tasks}") + logger.info(f"completed: {completed_tasks._list}") try: - new_task: TaskModel = controller._submitted_tasks_queue.get_nowait() + new_task: TaskModel = ( + controller._submitted_tasks_queue.get_nowait() + ) await submitted_tasks.append(new_task) logger.info("got new submitted task") logger.info(f"submitted: {submitted_tasks}") except asyncio.QueueEmpty: logger.info("task completion queue is empty") - # try: - # task_res: TaskResult = controller._completed_tasks_queue.get_nowait() - # logger.info("got new completed task result") - # except asyncio.QueueEmpty: - # task_res = None - # logger.info("task completion queue is empty") - - # if task_res: - # if task_res.task_id in [t.task_id for t in submitted_tasks]: - # ix, task = next( - # (ix, t) - # for ix, t in enumerate(submitted_tasks) - # if t.task_id == task_res.task_id - # ) - # task.status = TaskStatus.COMPLETED - # task.chat_history.append( - # ChatMessage(role="assistant", content=task_res.result) - # ) - # del submitted_tasks[ix] - # completed_tasks.append(task) - # logger.info(f"updated task status from submitted to completed.") - # elif task_res.task_id in [t.task_id for t in human_required_tasks]: - # remove_from_list_closure( - # st.session_state.human_required_tasks, - # TaskStatus.HUMAN_REQUIRED, - # ) + try: + task_res: TaskResult = ( + controller._completed_tasks_queue.get_nowait() + ) + logger.info("got new completed task result") + except asyncio.QueueEmpty: + task_res = None + logger.info("task completion queue is empty") + + if task_res: + if task_res.task_id in [t.task_id for t in submitted_tasks]: + ix, task = next( + (ix, t) + for ix, t in enumerate(submitted_tasks) + if t.task_id == task_res.task_id + ) + task.status = TaskStatus.COMPLETED + task.chat_history.append( + ChatMessage(role="assistant", content=task_res.result) + ) + await submitted_tasks.delete(ix) + await completed_tasks.append(task) + logger.info("updated task status from submitted to completed.") + elif task_res.task_id in [t.task_id for t in human_required_tasks]: + remove_from_list_closure( + st.session_state.human_required_tasks, + TaskStatus.HUMAN_REQUIRED, + ) ctr.text("Task Status") tasks = ( @@ -194,10 +208,11 @@ async def listening_to_queue(ctr) -> None: ) n_submitted = await submitted_tasks.length() + n_completed = await completed_tasks.length() status = ( ["submitted"] * n_submitted + ["human_required"] * len(human_required_tasks) - + ["completed"] * len(completed_tasks) + + ["completed"] * n_completed ) data = {"tasks": tasks, "status": status} logger.info(f"data: {data}") @@ -205,7 +220,7 @@ async def listening_to_queue(ctr) -> None: ctr.dataframe( df, selection_mode="single-row", use_container_width=True ) # Same as st.write(df) - await asyncio.sleep(5) + await asyncio.sleep(1) bottom = st.empty()