Skip to content

Commit

Permalink
async-safe list works
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Jul 26, 2024
1 parent 8a7d748 commit 78eee8b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 42 deletions.
6 changes: 4 additions & 2 deletions snowflake_cybersyn_demo/apps/async_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from asyncio import Lock


Expand All @@ -22,6 +20,10 @@ async def pop(self):
async with self._lock:
return self._list.pop()

async def delete(self, index):
async with self._lock:
del self._list[index]

async def get(self, index):
async with self._lock:
return self._list[index]
Expand Down
95 changes: 55 additions & 40 deletions snowflake_cybersyn_demo/apps/streamlit.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import asyncio
import logging
from typing import Dict, List
from contextvars import ContextVar

import pandas as pd
import streamlit as st
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI

from snowflake_cybersyn_demo.apps.async_list import AsyncSafeList
from snowflake_cybersyn_demo.apps.controller import (
Controller,
TaskModel,
TaskResult,
TaskStatus,
)
from snowflake_cybersyn_demo.apps.async_list import AsyncSafeList

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,8 +75,12 @@
for m in st.session_state.messages
]
)
response = st.write_stream(controller._llama_index_stream_wrapper(stream))
st.session_state.messages.append({"role": "assistant", "content": response})
response = st.write_stream(
controller._llama_index_stream_wrapper(stream)
)
st.session_state.messages.append(
{"role": "assistant", "content": response}
)


@st.cache_resource
Expand All @@ -100,7 +103,10 @@ async def launch() -> None:
return asyncio.run(launch())


start_consuming_callable, final_task_consuming_callable = get_consuming_callables()
(
start_consuming_callable,
final_task_consuming_callable,
) = get_consuming_callables()


def remove_from_list_closure(
Expand All @@ -115,10 +121,14 @@ def remove_from_list_closure(
over to the completed list.
"""
ix, task = next(
(ix, t) for ix, t in enumerate(task_list) if t.task_id == task_res.task_id
(ix, t)
for ix, t in enumerate(task_list)
if t.task_id == task_res.task_id
)
task.status = TaskStatus.COMPLETED
task.chat_history.append(ChatMessage(role="assistant", content=task_res.result))
task.chat_history.append(
ChatMessage(role="assistant", content=task_res.result)
)
del task_list[ix]

# if current_task:
Expand All @@ -134,10 +144,11 @@ def remove_from_list_closure(
@st.cache_resource
def get_async_safe_lists():
submitted_tasks = AsyncSafeList()
return submitted_tasks
completed_tasks = AsyncSafeList()
return submitted_tasks, completed_tasks


submitted_tasks = get_async_safe_lists()
submitted_tasks, completed_tasks = get_async_safe_lists()


async def listening_to_queue(ctr) -> None:
Expand All @@ -146,45 +157,48 @@ async def listening_to_queue(ctr) -> None:
f_task = asyncio.create_task(final_task_consuming_callable()) # noqa: F841

human_required_tasks = []
completed_tasks = []
while True:
logger.info(f"submitted: {submitted_tasks._list}")
# logger.info(f"completed: {completed_tasks}")
logger.info(f"completed: {completed_tasks._list}")

try:
new_task: TaskModel = controller._submitted_tasks_queue.get_nowait()
new_task: TaskModel = (
controller._submitted_tasks_queue.get_nowait()
)
await submitted_tasks.append(new_task)
logger.info("got new submitted task")
logger.info(f"submitted: {submitted_tasks}")
except asyncio.QueueEmpty:
logger.info("task completion queue is empty")

# try:
# task_res: TaskResult = controller._completed_tasks_queue.get_nowait()
# logger.info("got new completed task result")
# except asyncio.QueueEmpty:
# task_res = None
# logger.info("task completion queue is empty")

# if task_res:
# if task_res.task_id in [t.task_id for t in submitted_tasks]:
# ix, task = next(
# (ix, t)
# for ix, t in enumerate(submitted_tasks)
# if t.task_id == task_res.task_id
# )
# task.status = TaskStatus.COMPLETED
# task.chat_history.append(
# ChatMessage(role="assistant", content=task_res.result)
# )
# del submitted_tasks[ix]
# completed_tasks.append(task)
# logger.info(f"updated task status from submitted to completed.")
# elif task_res.task_id in [t.task_id for t in human_required_tasks]:
# remove_from_list_closure(
# st.session_state.human_required_tasks,
# TaskStatus.HUMAN_REQUIRED,
# )
try:
task_res: TaskResult = (
controller._completed_tasks_queue.get_nowait()
)
logger.info("got new completed task result")
except asyncio.QueueEmpty:
task_res = None
logger.info("task completion queue is empty")

if task_res:
if task_res.task_id in [t.task_id for t in submitted_tasks]:
ix, task = next(
(ix, t)
for ix, t in enumerate(submitted_tasks)
if t.task_id == task_res.task_id
)
task.status = TaskStatus.COMPLETED
task.chat_history.append(
ChatMessage(role="assistant", content=task_res.result)
)
await submitted_tasks.delete(ix)
await completed_tasks.append(task)
logger.info("updated task status from submitted to completed.")
elif task_res.task_id in [t.task_id for t in human_required_tasks]:
remove_from_list_closure(
st.session_state.human_required_tasks,
TaskStatus.HUMAN_REQUIRED,
)

ctr.text("Task Status")
tasks = (
Expand All @@ -194,18 +208,19 @@ async def listening_to_queue(ctr) -> None:
)

n_submitted = await submitted_tasks.length()
n_completed = await completed_tasks.length()
status = (
["submitted"] * n_submitted
+ ["human_required"] * len(human_required_tasks)
+ ["completed"] * len(completed_tasks)
+ ["completed"] * n_completed
)
data = {"tasks": tasks, "status": status}
logger.info(f"data: {data}")
df = pd.DataFrame(data)
ctr.dataframe(
df, selection_mode="single-row", use_container_width=True
) # Same as st.write(df)
await asyncio.sleep(5)
await asyncio.sleep(1)


bottom = st.empty()
Expand Down

0 comments on commit 78eee8b

Please sign in to comment.