wire human req

nerdai committed Jul 30, 2024
1 parent b287ff0 commit aac96da
Showing 3 changed files with 109 additions and 73 deletions.
24 changes: 16 additions & 8 deletions snowflake_cybersyn_demo/additional_services/human_in_the_loop.py
@@ -1,6 +1,7 @@
import asyncio
import logging
from typing import Any, Dict
import queue
from typing import Any, Dict, TypedDict

from llama_agents import HumanService, ServiceComponent
from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
@@ -21,21 +22,24 @@
localhost = load_from_env("LOCALHOST")


class HumanRequest(TypedDict):
prompt: str
task_id: str


# # human in the loop function
def human_service_factory(
human_input_request_queue: asyncio.Queue[Dict[str, str]],
human_input_result_queue: asyncio.Queue[str],
human_input_request_queue: queue.Queue[Dict[str, str]],
human_input_result_queue: queue.Queue[str],
) -> HumanService:
async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str:
logger.info("human input fn invoked.")
await human_input_request_queue.put(
{"prompt": prompt, "task_id": task_id}
)
human_input_request_queue.put({"prompt": prompt, "task_id": task_id})
logger.info("placed new prompt in queue.")

# poll until human answer is stored
async def _poll_for_human_input_result() -> str:
return await human_input_result_queue.get()
return human_input_result_queue.get()

try:
human_input = await asyncio.wait_for(
@@ -67,5 +71,9 @@ async def _poll_for_human_input_result() -> str:


# used by control plane
human_service = human_service_factory(asyncio.Queue(), asyncio.Queue())
human_input_request_queue = queue.Queue()
human_input_result_queue = queue.Queue()
human_service = human_service_factory(
human_input_request_queue, human_input_result_queue
)
human_component = ServiceComponent.from_service_definition(human_service)
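
The change above swaps asyncio.Queue for the thread-safe queue.Queue so that the async human_input_fn can exchange data with code running on another thread (the Streamlit UI). Below is a minimal, self-contained sketch of that cross-thread hand-off; the names are illustrative rather than taken from the commit, and it polls with get_nowait instead of the blocking get() used in the diff.

import asyncio
import queue
import threading

# Illustrative stand-ins for the request/result queues wired up in this commit.
request_queue: "queue.Queue[dict]" = queue.Queue()
result_queue: "queue.Queue[str]" = queue.Queue()


async def ask_human(prompt: str, task_id: str) -> str:
    """Publish a request, then poll the thread-safe queue without blocking the event loop."""
    request_queue.put({"prompt": prompt, "task_id": task_id})
    while True:
        try:
            # The answer is supplied by another thread (the UI side).
            return result_queue.get_nowait()
        except queue.Empty:
            await asyncio.sleep(0.1)  # yield so other coroutines keep running


def ui_thread() -> None:
    """Stand-in for the Streamlit side: consume the request and push back an answer."""
    req = request_queue.get()  # blocks this thread only, not the event loop
    result_queue.put(f"answer to: {req['prompt']}")


threading.Thread(target=ui_thread, daemon=True).start()
print(asyncio.run(ask_human("Need approval?", task_id="task-1")))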
43 changes: 1 addition & 42 deletions snowflake_cybersyn_demo/apps/controller.py
@@ -1,23 +1,12 @@
import asyncio
import logging
import queue
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Generator, List, Optional

import streamlit as st
from llama_agents import (
CallableMessageConsumer,
LlamaAgentsClient,
QueueMessage,
)
from llama_agents.types import ActionTypes, TaskResult
from llama_agents import LlamaAgentsClient
from llama_index.core.llms import ChatMessage, ChatResponseGen

from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
human_service_factory,
)

logger = logging.getLogger(__name__)


@@ -39,18 +28,9 @@ class TaskModel:
class Controller:
def __init__(
self,
human_in_loop_queue: asyncio.Queue,
human_in_loop_result_queue: asyncio.Queue,
submitted_tasks_queue: queue.Queue,
control_plane_host: str = "127.0.0.1",
control_plane_port: Optional[int] = 8000,
):
self.human_in_loop_queue = human_in_loop_queue
self.human_in_loop_result_queue = human_in_loop_result_queue
self.submitted_tasks_queue: queue.Queue[TaskModel] = queue.Queue()
self._human_service = human_service_factory(
human_in_loop_queue, human_in_loop_result_queue
)
self._client = LlamaAgentsClient(
control_plane_url=(
f"http://{control_plane_host}:{control_plane_port}"
@@ -60,27 +40,6 @@ def __init__(
)
self._step_interval = 0.5
self._timeout = 60
self._raise_timeout = False
self._human_in_the_loop_task: Optional[str] = None
self._human_input: Optional[str] = None
self._final_task_consumer = CallableMessageConsumer(
message_type="human", handler=self._process_completed_task_messages
)
self._completed_tasks_queue: asyncio.Queue[TaskResult] = asyncio.Queue()

async def _process_completed_task_messages(
self, message: QueueMessage, **kwargs: Any
) -> None:
"""Consumer of completed tasks.
By default control plane sends to message consumer of type "human".
The process message logic contained here simply puts the TaskResult into
a queue that is continuously polled via a gr.Timer().
"""
if message.action == ActionTypes.COMPLETED_TASK:
task_res = TaskResult(**message.data)
await self._completed_tasks_queue.put(task_res)
logger.info("Added task result to queue")

def _llama_index_stream_wrapper(
self,
115 changes: 92 additions & 23 deletions snowflake_cybersyn_demo/apps/streamlit.py
@@ -1,19 +1,21 @@
import asyncio
import logging
import queue
import time
import threading
from typing import Dict, List, Optional
import time
from typing import Optional

import pandas as pd
import streamlit as st
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI

from snowflake_cybersyn_demo.apps.async_list import AsyncSafeList
from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
HumanRequest,
HumanService,
)
from snowflake_cybersyn_demo.apps.controller import (
Controller,
TaskModel,
TaskResult,
TaskStatus,
)
@@ -24,29 +26,48 @@
llm = OpenAI(model="gpt-4o-mini")
control_plane_host = "0.0.0.0"
control_plane_port = 8001
human_input_request_queue: asyncio.Queue[Dict[str, str]] = asyncio.Queue()
human_input_result_queue: asyncio.Queue[str] = asyncio.Queue()


st.set_page_config(layout="wide")


@st.cache_resource
def startup():
human_input_request_queue: asyncio.Queue[Dict[str, str]] = asyncio.Queue()
human_input_result_queue: asyncio.Queue[str] = asyncio.Queue()
submitted_tasks_queue = queue.Queue()
from snowflake_cybersyn_demo.additional_services.human_in_the_loop import (
human_input_request_queue,
human_input_result_queue,
human_service,
)

completed_tasks_queue = queue.Queue()
controller = Controller(
human_in_loop_queue=human_input_request_queue,
human_in_loop_result_queue=human_input_result_queue,
control_plane_host=control_plane_host,
control_plane_port=control_plane_port,
submitted_tasks_queue=submitted_tasks_queue,
)

async def start_consuming_human_tasks(human_service: HumanService):
human_task_consuming_callable = (
await human_service.message_queue.register_consumer(
human_service.as_consumer()
)
)

ht_task = asyncio.create_task(human_task_consuming_callable())

launch_task = asyncio.create_task(human_service.processing_loop())

await asyncio.Future()

hr_thread = threading.Thread(
name="Human Request thread",
target=asyncio.run,
args=(start_consuming_human_tasks(human_service),),
daemon=False,
)
hr_thread.start()

final_task_consumer = FinalTaskConsumer(
message_queue=controller._human_service.message_queue,
message_queue=human_service.message_queue,
completed_tasks_queue=completed_tasks_queue,
)

@@ -58,24 +79,33 @@ async def start_consuming_finalized_tasks(final_task_consumer):
await final_task_consuming_callable()

# server thread will remain active as long as streamlit thread is running, or is manually shutdown
thread = threading.Thread(
ft_thread = threading.Thread(
name="Consuming thread",
target=asyncio.run,
args=(start_consuming_finalized_tasks(final_task_consumer),),
daemon=False,
)
thread.start()
ft_thread.start()

time.sleep(5)
st.session_state.consuming = True
logger.info("Started consuming.")

return controller, submitted_tasks_queue, completed_tasks_queue, final_task_consumer
return (
controller,
completed_tasks_queue,
final_task_consumer,
human_input_request_queue,
human_input_result_queue,
)


controller, submitted_tasks_queue, completed_tasks_queue, final_task_consumer = (
startup()
)
(
controller,
completed_tasks_queue,
final_task_consumer,
human_input_request_queue,
human_input_result_queue,
) = startup()
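
In startup() above, the human service's consumer and processing loop run on a dedicated thread with its own event loop (target=asyncio.run), and @st.cache_resource keeps that thread from being recreated on every Streamlit rerun. A stripped-down sketch of the same pattern, with no llama_agents dependency; consume_forever is a placeholder, not the real consumer:

import asyncio
import threading

import streamlit as st


@st.cache_resource  # runs once per server process, not once per script rerun
def start_background_consumer() -> threading.Thread:
    async def consume_forever() -> None:
        while True:
            # Placeholder for registering a consumer and processing messages.
            await asyncio.sleep(1)

    thread = threading.Thread(
        name="Consumer thread",
        target=asyncio.run,            # the thread gets its own event loop
        args=(consume_forever(),),
        daemon=True,
    )
    thread.start()
    return thread


consumer_thread = start_background_consumer()
st.write(f"consumer alive: {consumer_thread.is_alive()}")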


### App
@@ -124,8 +154,12 @@ async def start_consuming_finalized_tasks(final_task_consumer):
for m in st.session_state.messages
]
)
response = st.write_stream(controller._llama_index_stream_wrapper(stream))
st.session_state.messages.append({"role": "assistant", "content": response})
response = st.write_stream(
controller._llama_index_stream_wrapper(stream)
)
st.session_state.messages.append(
{"role": "assistant", "content": response}
)


@st.experimental_fragment(run_every="30s")
@@ -161,7 +195,7 @@ def process_completed_tasks(completed_queue: queue.Queue):
task_res = completed_queue.get_nowait()
logger.info("got new task result")
except queue.Empty:
logger.info(f"task result queue is empty.")
logger.info("task result queue is empty.")

if task_res:
try:
@@ -185,3 +219,38 @@ def process_completed_tasks(completed_queue: queue.Queue):


process_completed_tasks(completed_queue=completed_tasks_queue)


@st.experimental_fragment(run_every=5)
def process_human_input_requests(
human_requests_queue: queue.Queue[HumanRequest],
):
human_req: Optional[HumanRequest] = None
try:
human_req = human_requests_queue.get_nowait()
logger.info("got new human request")
except queue.Empty:
logger.info("human request queue is empty.")

if human_req:
try:
task_list = st.session_state.get("submitted_tasks")
print(f"submitted tasks: {task_list}")
ix, task = next(
(ix, t)
for ix, t in enumerate(task_list)
if t.task_id == human_req["task_id"]
)
task.status = TaskStatus.COMPLETED
task.chat_history.append(
ChatMessage(role="assistant", content=human_req["prompt"])
)
del task_list[ix]
st.session_state.submitted_tasks = task_list
st.session_state.human_required_tasks.append(task)
logger.info("updated submitted and human required tasks list.")
except StopIteration:
raise ValueError("Cannot find task in list of tasks.")


process_human_input_requests(human_requests_queue=human_input_request_queue)
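
The new fragment above only surfaces requests; the reply path, where the operator's answer is pushed onto human_input_result_queue so the poller in human_input_fn can return it, is not part of this commit. A hedged sketch of what that UI side could look like; the function name, widget labels, and session key are assumptions, not code from this repository:

import queue

import streamlit as st


def render_human_reply_box(result_queue: "queue.Queue[str]") -> None:
    """Collect the operator's answer and hand it back to the waiting human_input_fn."""
    answer = st.text_input("Your answer to the pending request", key="human_reply")
    if st.button("Submit answer") and answer:
        result_queue.put(answer)  # picked up by the poller on the service side
        st.success("Answer submitted.")

Called from the main script this would be, for example, render_human_reply_box(human_input_result_queue).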
