diff --git a/snowflake_cybersyn_demo/workflows/human_input.py b/snowflake_cybersyn_demo/workflows/human_input.py index 8fbcee9..ea3d1a3 100644 --- a/snowflake_cybersyn_demo/workflows/human_input.py +++ b/snowflake_cybersyn_demo/workflows/human_input.py @@ -1,15 +1,31 @@ -from typing import Callable +from typing import Any, Awaitable, Protocol, runtime_checkable from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step +@runtime_checkable +class HumanInputFn(Protocol): + """Protocol for getting human input.""" + + def __call__(self, prompt: str, **kwargs: Any) -> Awaitable[str]: + ... + + +async def default_human_input_fn(prompt: str, **kwargs: Any) -> str: + return input(prompt) + + class HumanInputWorkflow(Workflow): - input: Callable = input + def __init__( + self, input: HumanInputFn = default_human_input_fn, **kwargs: Any + ): + super().__init__(**kwargs) + self.input = input @step async def human_input(self, ev: StartEvent) -> StopEvent: prompt = str(ev.get("prompt", "")) - human_input = self.input(prompt) + human_input = await self.input(prompt) return StopEvent(result=human_input)