Skip to content

Commit

Permalink
simplify human_service
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Jul 30, 2024
1 parent aac96da commit 5a360c1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 53 deletions.
74 changes: 33 additions & 41 deletions snowflake_cybersyn_demo/additional_services/human_in_the_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,44 @@ class HumanRequest(TypedDict):


# # human in the loop function
def human_service_factory(
human_input_request_queue: queue.Queue[Dict[str, str]],
human_input_result_queue: queue.Queue[str],
) -> HumanService:
async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str:
logger.info("human input fn invoked.")
human_input_request_queue.put({"prompt": prompt, "task_id": task_id})
logger.info("placed new prompt in queue.")
human_input_request_queue = queue.Queue()
human_input_result_queue = queue.Queue()

# poll until human answer is stored
async def _poll_for_human_input_result() -> str:
return human_input_result_queue.get()

try:
human_input = await asyncio.wait_for(
_poll_for_human_input_result(),
timeout=6000,
)
logger.info(f"Recieved human input: {human_input}")
except (
asyncio.exceptions.TimeoutError,
asyncio.TimeoutError,
TimeoutError,
):
logger.info(f"Timeout reached for tool_call with prompt {prompt}")
human_input = "Something went wrong."
async def human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str:
logger.info("human input fn invoked.")
human_input_request_queue.put({"prompt": prompt, "task_id": task_id})
logger.info("placed new prompt in queue.")

return human_input
# poll until human answer is stored
async def _poll_for_human_input_result() -> str:
return human_input_result_queue.get()

# create our multi-agent framework components
message_queue = RabbitMQMessageQueue(
url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
)
human_service = HumanService(
message_queue=message_queue,
description="Answers queries about math.",
fn_input=human_input_fn,
human_input_prompt="{input_str}",
)
return human_service
try:
human_input = await asyncio.wait_for(
_poll_for_human_input_result(),
timeout=6000,
)
logger.info(f"Recieved human input: {human_input}")
except (
asyncio.exceptions.TimeoutError,
asyncio.TimeoutError,
TimeoutError,
):
logger.info(f"Timeout reached for tool_call with prompt {prompt}")
human_input = "Something went wrong."

return human_input

# used by control plane
human_input_request_queue = queue.Queue()
human_input_result_queue = queue.Queue()
human_service = human_service_factory(
human_input_request_queue, human_input_result_queue

# create our multi-agent framework components
message_queue = RabbitMQMessageQueue(
url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
)
human_service = HumanService(
message_queue=message_queue,
description="Answers queries about math.",
fn_input=human_input_fn,
human_input_prompt="{input_str}",
)
human_component = ServiceComponent.from_service_definition(human_service)
18 changes: 6 additions & 12 deletions snowflake_cybersyn_demo/apps/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
)
from snowflake_cybersyn_demo.apps.controller import (
Controller,
TaskResult,
TaskStatus,
)
from snowflake_cybersyn_demo.apps.final_task_consumer import FinalTaskConsumer
Expand All @@ -37,6 +36,7 @@ def startup():
human_input_request_queue,
human_input_result_queue,
human_service,
message_queue,
)

completed_tasks_queue = queue.Queue()
Expand All @@ -46,10 +46,8 @@ def startup():
)

async def start_consuming_human_tasks(human_service: HumanService):
human_task_consuming_callable = (
await human_service.message_queue.register_consumer(
human_service.as_consumer()
)
human_task_consuming_callable = await message_queue.register_consumer(
human_service.as_consumer()
)

ht_task = asyncio.create_task(human_task_consuming_callable())
Expand All @@ -67,7 +65,7 @@ async def start_consuming_human_tasks(human_service: HumanService):
hr_thread.start()

final_task_consumer = FinalTaskConsumer(
message_queue=human_service.message_queue,
message_queue=message_queue,
completed_tasks_queue=completed_tasks_queue,
)

Expand Down Expand Up @@ -154,12 +152,8 @@ async def start_consuming_finalized_tasks(final_task_consumer):
for m in st.session_state.messages
]
)
response = st.write_stream(
controller._llama_index_stream_wrapper(stream)
)
st.session_state.messages.append(
{"role": "assistant", "content": response}
)
response = st.write_stream(controller._llama_index_stream_wrapper(stream))
st.session_state.messages.append({"role": "assistant", "content": response})


@st.experimental_fragment(run_every="30s")
Expand Down

0 comments on commit 5a360c1

Please sign in to comment.