
wire popover and change status from human required to complete
nerdai committed Aug 2, 2024
1 parent 3736646 commit 83d286c
Showing 3 changed files with 121 additions and 53 deletions.
@@ -37,7 +37,15 @@ async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str:

     # poll until human answer is stored
     async def _poll_for_human_input_result() -> str:
-        return human_input_result_queue.get()
+        human_input = None
+        while human_input is None:
+            try:
+                human_input = human_input_result_queue.get_nowait()
+            except queue.Empty:
+                human_input = None
+                await asyncio.sleep(0.1)
+        logger.info("human input received")
+        return human_input

     try:
         human_input = await asyncio.wait_for(
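The hunk above swaps a blocking queue.Queue.get() for a non-blocking poll so the asyncio event loop stays responsive while the agent waits on a human answer. Below is a minimal, self-contained sketch of that pattern; the names result_queue and poll_for_result are illustrative stand-ins, not part of the demo's API.

import asyncio
import queue

# Sketch of the non-blocking poll pattern used above (assumed names): a
# thread-safe queue.Queue is checked with get_nowait() so the event loop is
# never blocked, and the coroutine yields with asyncio.sleep() between tries.

result_queue: queue.Queue = queue.Queue()


async def poll_for_result(timeout: float = 5.0) -> str:
    async def _poll() -> str:
        while True:
            try:
                return result_queue.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.1)

    return await asyncio.wait_for(_poll(), timeout=timeout)


async def main() -> None:
    # Simulate another thread (e.g. the Streamlit UI) supplying the answer
    # half a second later.
    asyncio.get_running_loop().call_later(0.5, result_queue.put_nowait, "42")
    print(await poll_for_result())


if __name__ == "__main__":
    asyncio.run(main())

Polling with get_nowait() plus a short sleep trades up to roughly 100 ms of extra latency per answer for never blocking the loop that also services the rest of the app.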
63 changes: 46 additions & 17 deletions snowflake_cybersyn_demo/apps/controller.py
@@ -1,4 +1,5 @@
 import logging
+import queue
 import time
 from dataclasses import dataclass, field
 from enum import Enum
@@ -78,6 +79,18 @@ def handle_task_submission(self) -> None:
         st.session_state.current_task = task
         st.session_state.task_input = ""

+    def get_human_input_handler(
+        self, human_input_result_queue: queue.Queue
+    ) -> Callable:
+        def human_input_handler() -> None:
+            human_input = st.session_state.human_input
+            if human_input == "":
+                return
+            human_input_result_queue.put_nowait(human_input)
+            logger.info("pushed human input to human input result queue.")
+
+        return human_input_handler
+
     def update_associated_task_to_completed_status(
         self,
         task_res: TaskResult,
@@ -87,24 +100,40 @@ def update_associated_task_to_completed_status(
         Update session_state lists as well.
         """
-        try:
-            task_list = st.session_state.get("submitted_tasks")
-            print(f"submitted tasks: {task_list}")
-            ix, task = next(
-                (ix, t)
-                for ix, t in enumerate(task_list)
-                if t.task_id == task_res.task_id
-            )
-            task.status = TaskStatus.COMPLETED
-            task.history.append(
-                ChatMessage(role="assistant", content=task_res.result)
-            )
+
+        def remove_task_from_list(
+            task_list: List[TaskModel],
+        ) -> List[TaskModel]:
+            try:
+                ix, task = next(
+                    (ix, t)
+                    for ix, t in enumerate(task_list)
+                    if t.task_id == task_res.task_id
+                )
+                task.status = TaskStatus.COMPLETED
+                task.history.append(
+                    ChatMessage(role="assistant", content=task_res.result)
+                )
+                del task_list[ix]
+                st.session_state.completed_tasks.append(task)
+                logger.info("updated submitted and completed tasks list.")
+            except StopIteration:
+                raise ValueError("Cannot find task in list of tasks.")
+            return task_list
+
+        submitted_tasks = st.session_state.get("submitted_tasks")
+        human_required_tasks = st.session_state.get("human_required_tasks")
+
+        if task_res.task_id in [t.task_id for t in submitted_tasks]:
+            updated_task_list = remove_task_from_list(submitted_tasks)
+            st.session_state.submitted_tasks = updated_task_list
+        elif task_res.task_id in [t.task_id for t in human_required_tasks]:
+            updated_task_list = remove_task_from_list(human_required_tasks)
+            st.session_state.human_required_tasks = updated_task_list
+        else:
+            raise ValueError(
+                "Completed task not in submitted or human_required lists."
+            )
-            del task_list[ix]
-            st.session_state.submitted_tasks = task_list
-            st.session_state.completed_tasks.append(task)
-            logger.info("updated submitted and completed tasks list.")
-        except StopIteration:
-            raise ValueError("Cannot find task in list of tasks.")

     def update_associated_task_to_human_required_status(
         self,
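The restructured update_associated_task_to_completed_status above moves a finished task out of whichever list currently holds it (submitted or human_required) and into the completed list. A framework-free sketch of that bookkeeping, using a stand-in TaskModel rather than the demo's real model:

from dataclasses import dataclass, field
from typing import List


# Minimal sketch of the status bookkeeping: look the task up by task_id in
# either source list, mark it completed, and move it to the completed list.
@dataclass
class TaskModel:
    task_id: str
    status: str = "submitted"
    history: List[str] = field(default_factory=list)


def move_to_completed(
    task_id: str,
    result: str,
    submitted: List[TaskModel],
    human_required: List[TaskModel],
    completed: List[TaskModel],
) -> None:
    for task_list in (submitted, human_required):
        for ix, task in enumerate(task_list):
            if task.task_id == task_id:
                task.status = "completed"
                task.history.append(result)
                completed.append(task_list.pop(ix))
                return
    raise ValueError("Completed task not in submitted or human_required lists.")


submitted = [TaskModel("t-1")]
human_required: List[TaskModel] = []
completed: List[TaskModel] = []
move_to_completed("t-1", "done", submitted, human_required, completed)
assert completed[0].status == "completed"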
101 changes: 66 additions & 35 deletions snowflake_cybersyn_demo/apps/streamlit.py
@@ -140,6 +140,8 @@ async def start_consuming_finalized_tasks(
     st.session_state.messages = []
 if "current_task" not in st.session_state:
     st.session_state.current_task = None
+if "human_input" not in st.session_state:
+    st.session_state.human_input = ""


 left, right = st.columns([1, 2], vertical_alignment="top")
@@ -154,39 +156,46 @@ async def start_consuming_finalized_tasks(


 def chat_window() -> None:
-    with st.sidebar:
-
-        @st.experimental_fragment(run_every="5s")
-        def show_chat_window() -> None:
-            messages_container = st.container(height=500)
-            with messages_container:
-                if st.session_state.current_task:
-                    messages = [
-                        m.dict() for m in st.session_state.current_task.history
-                    ]
-                    for message in messages:
-                        with st.chat_message(message["role"]):
-                            st.markdown(message["content"])
-                else:
-                    st.empty()
-
-        show_chat_window()
-
-        if _ := st.chat_input("What is up?"):
-            pass
-            # st.session_state.messages.append({"role": "user", "content": prompt})
-            # with st.chat_message("user"):
-            #     st.markdown(prompt)
-
-            # with st.chat_message("assistant"):
-            #     stream = llm.stream_chat(
-            #         messages=[
-            #             ChatMessage(role=m["role"], content=m["content"])
-            #             for m in st.session_state.messages
-            #         ]
-            #     )
-            #     response = st.write_stream(controller.llama_index_stream_wrapper(stream))
-            #     st.session_state.messages.append({"role": "assistant", "content": response})
+    pass
+    # with st.sidebar:
+
+    #     @st.experimental_fragment(run_every="5s")
+    #     def show_chat_window() -> None:
+    #         messages_container = st.container(height=500)
+    #         with messages_container:
+    #             if st.session_state.current_task:
+    #                 messages = [m.dict() for m in st.session_state.current_task.history]
+    #                 for message in messages:
+    #                     with st.chat_message(message["role"]):
+    #                         st.markdown(message["content"])
+    #             else:
+    #                 st.empty()
+
+    #     show_chat_window()
+
+    #     if human_input := st.chat_input("Provide human input."):
+    #         if st.session_state.current_task:
+    #             st.session_state.current_task.history.append(
+    #                 ChatMessage(role="user", content=human_input)
+    #             )
+    #             human_input_result_queue.put(human_input)
+    #             time.sleep(1)
+    #             logger.info("pushed human input to human input result queue.")
+
+    #         logger.info(f"HUMAN INPUT: {human_input}")
+    #         st.session_state.messages.append({"role": "user", "content": prompt})
+    #         with st.chat_message("user"):
+    #             st.markdown(prompt)
+
+    #         with st.chat_message("assistant"):
+    #             stream = llm.stream_chat(
+    #                 messages=[
+    #                     ChatMessage(role=m["role"], content=m["content"])
+    #                     for m in st.session_state.messages
+    #                 ]
+    #             )
+    #             response = st.write_stream(controller.llama_index_stream_wrapper(stream))
+    #             st.session_state.messages.append({"role": "assistant", "content": response})


 @st.experimental_fragment(run_every="30s")
@@ -210,10 +219,16 @@ def task_df() -> None:
         + ["human_required"] * len(st.session_state.human_required_tasks)
         + ["completed"] * len(st.session_state.completed_tasks)
     )
-    data = {"task_id": task_ids, "input": tasks, "status": status}

+    data = {
+        "task_id": task_ids,
+        "input": tasks,
+        "status": status,
+    }
+
+    logger.info(f"data: {data}")
     df = pd.DataFrame(data)
-    st.dataframe(
+    event = st.dataframe(
         df,
         hide_index=True,
         selection_mode="single-row",
@@ -222,6 +237,22 @@ def task_df() -> None:
         key="task_df",
     )

+    popover_enabled = (
+        len(event.selection["rows"]) > 0
+        and st.session_state.current_task.status == "human_required"
+    )
+    with st.popover("Human Input", disabled=not popover_enabled):
+        if popover_enabled:
+            human_prompt = st.session_state.current_task.history[-1].content
+            st.markdown(human_prompt)
+            st.text_input(
+                "Provide human input",
+                key="human_input",
+                on_change=controller.get_human_input_handler(
+                    human_input_result_queue
+                ),
+            )
+

 task_df()

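The popover wiring above gates human input on a dataframe row selection. Below is a stripped-down sketch of the same idea as a standalone Streamlit page; the sample dataframe and the submit_human_input callback are illustrative, while the queue hand-off mirrors get_human_input_handler in controller.py.

import queue

import pandas as pd
import streamlit as st

# Sketch only: in the real demo the queue is created once and shared with the
# background agent; here it simply illustrates the callback wiring.
human_input_result_queue: queue.Queue = queue.Queue()


def submit_human_input() -> None:
    # Mirrors controller.get_human_input_handler: ignore empty input,
    # otherwise push the answer onto the queue the agent is polling.
    human_input = st.session_state.human_input
    if human_input == "":
        return
    human_input_result_queue.put_nowait(human_input)


df = pd.DataFrame(
    {
        "task_id": ["t-1"],
        "input": ["What is 2 + 2?"],
        "status": ["human_required"],
    }
)
event = st.dataframe(
    df,
    hide_index=True,
    selection_mode="single-row",
    on_select="rerun",
    key="task_df",
)

# Enable the popover only when a row is selected.
popover_enabled = len(event.selection["rows"]) > 0
with st.popover("Human Input", disabled=not popover_enabled):
    if popover_enabled:
        selected_row = df.iloc[event.selection["rows"][0]]
        st.markdown(selected_row["input"])
        st.text_input(
            "Provide human input",
            key="human_input",
            on_change=submit_human_input,
        )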
