Skip to content

Commit

Permalink
use nested workflow for human in the loop
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Sep 12, 2024
1 parent 7155a04 commit b7ab522
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions snowflake_cybersyn_demo/workflows/government_essentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from llama_index.llms.openai import OpenAI

import snowflake_cybersyn_demo.workflows._db as db
from snowflake_cybersyn_demo.workflows.human_input import HumanInputWorkflow


class StatisticsLookupEvent(Event):
Expand All @@ -34,14 +35,18 @@ async def retrieve_candidates_from_db(
return StatisticsLookupEvent(statistic_variables=stats_vars, city=city)

@step
async def human_input(self, ev: StatisticsLookupEvent) -> HumanInputEvent:
async def human_input(
self,
ev: StatisticsLookupEvent,
human_input_workflow: HumanInputWorkflow,
) -> HumanInputEvent:
stats_vars = "\n".join(ev.statistic_variables)
human_prompt = (
"List of statistic variables that exist in the database are provided below."
f"{stats_vars}"
"\n\nPlease select one.:\n\n"
)
human_input = input(human_prompt)
human_input = await human_input_workflow.run(prompt=human_prompt)

# use llm to clean up selection
llm = OpenAI("gpt-4o")
Expand Down Expand Up @@ -77,6 +82,7 @@ async def get_time_series_data(self, ev: HumanInputEvent) -> StopEvent:
# Local Testing
async def _test_workflow() -> None:
w = GovtEssentialsStatisticsWorkflow(timeout=None, verbose=False)
w.add_workflows(human_input_workflow=HumanInputWorkflow())
result = await w.run(city="New York")
print(str(result))

Expand Down

0 comments on commit b7ab522

Please sign in to comment.