From 5a360c1c4ac292c70ee1c4dd2e49c4f46fafc948 Mon Sep 17 00:00:00 2001 From: Andrei Fajardo Date: Tue, 30 Jul 2024 01:37:56 -0400 Subject: [PATCH] simplify human_service --- .../additional_services/human_in_the_loop.py | 74 +++++++++---------- snowflake_cybersyn_demo/apps/streamlit.py | 18 ++--- 2 files changed, 39 insertions(+), 53 deletions(-) diff --git a/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py b/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py index 97882c8..bd87b4f 100644 --- a/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py +++ b/snowflake_cybersyn_demo/additional_services/human_in_the_loop.py @@ -28,52 +28,44 @@ class HumanRequest(TypedDict): # # human in the loop function -def human_service_factory( - human_input_request_queue: queue.Queue[Dict[str, str]], - human_input_result_queue: queue.Queue[str], -) -> HumanService: - async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str: - logger.info("human input fn invoked.") - human_input_request_queue.put({"prompt": prompt, "task_id": task_id}) - logger.info("placed new prompt in queue.") +human_input_request_queue = queue.Queue() +human_input_result_queue = queue.Queue() - # poll until human answer is stored - async def _poll_for_human_input_result() -> str: - return human_input_result_queue.get() - try: - human_input = await asyncio.wait_for( - _poll_for_human_input_result(), - timeout=6000, - ) - logger.info(f"Recieved human input: {human_input}") - except ( - asyncio.exceptions.TimeoutError, - asyncio.TimeoutError, - TimeoutError, - ): - logger.info(f"Timeout reached for tool_call with prompt {prompt}") - human_input = "Something went wrong." +async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str: + logger.info("human input fn invoked.") + human_input_request_queue.put({"prompt": prompt, "task_id": task_id}) + logger.info("placed new prompt in queue.") - return human_input + # poll until human answer is stored + async def _poll_for_human_input_result() -> str: + return human_input_result_queue.get() - # create our multi-agent framework components - message_queue = RabbitMQMessageQueue( - url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/" - ) - human_service = HumanService( - message_queue=message_queue, - description="Answers queries about math.", - fn_input=human_input_fn, - human_input_prompt="{input_str}", - ) - return human_service + try: + human_input = await asyncio.wait_for( + _poll_for_human_input_result(), + timeout=6000, + ) + logger.info(f"Recieved human input: {human_input}") + except ( + asyncio.exceptions.TimeoutError, + asyncio.TimeoutError, + TimeoutError, + ): + logger.info(f"Timeout reached for tool_call with prompt {prompt}") + human_input = "Something went wrong." + return human_input -# used by control plane -human_input_request_queue = queue.Queue() -human_input_result_queue = queue.Queue() -human_service = human_service_factory( - human_input_request_queue, human_input_result_queue + +# create our multi-agent framework components +message_queue = RabbitMQMessageQueue( + url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/" +) +human_service = HumanService( + message_queue=message_queue, + description="Answers queries about math.", + fn_input=human_input_fn, + human_input_prompt="{input_str}", ) human_component = ServiceComponent.from_service_definition(human_service) diff --git a/snowflake_cybersyn_demo/apps/streamlit.py b/snowflake_cybersyn_demo/apps/streamlit.py index dfa6ca6..efe6c38 100644 --- a/snowflake_cybersyn_demo/apps/streamlit.py +++ b/snowflake_cybersyn_demo/apps/streamlit.py @@ -16,7 +16,6 @@ ) from snowflake_cybersyn_demo.apps.controller import ( Controller, - TaskResult, TaskStatus, ) from snowflake_cybersyn_demo.apps.final_task_consumer import FinalTaskConsumer @@ -37,6 +36,7 @@ def startup(): human_input_request_queue, human_input_result_queue, human_service, + message_queue, ) completed_tasks_queue = queue.Queue() @@ -46,10 +46,8 @@ def startup(): ) async def start_consuming_human_tasks(human_service: HumanService): - human_task_consuming_callable = ( - await human_service.message_queue.register_consumer( - human_service.as_consumer() - ) + human_task_consuming_callable = await message_queue.register_consumer( + human_service.as_consumer() ) ht_task = asyncio.create_task(human_task_consuming_callable()) @@ -67,7 +65,7 @@ async def start_consuming_human_tasks(human_service: HumanService): hr_thread.start() final_task_consumer = FinalTaskConsumer( - message_queue=human_service.message_queue, + message_queue=message_queue, completed_tasks_queue=completed_tasks_queue, ) @@ -154,12 +152,8 @@ async def start_consuming_finalized_tasks(final_task_consumer): for m in st.session_state.messages ] ) - response = st.write_stream( - controller._llama_index_stream_wrapper(stream) - ) - st.session_state.messages.append( - {"role": "assistant", "content": response} - ) + response = st.write_stream(controller._llama_index_stream_wrapper(stream)) + st.session_state.messages.append({"role": "assistant", "content": response}) @st.experimental_fragment(run_every="30s")