diff --git a/snowflake_cybersyn_demo/workflows/sql_queries.py b/snowflake_cybersyn_demo/workflows/_db.py
similarity index 65%
rename from snowflake_cybersyn_demo/workflows/sql_queries.py
rename to snowflake_cybersyn_demo/workflows/_db.py
index c4f85b1..5b92682 100644
--- a/snowflake_cybersyn_demo/workflows/sql_queries.py
+++ b/snowflake_cybersyn_demo/workflows/_db.py
@@ -12,7 +12,18 @@ snowflake_role = load_from_env("SNOWFLAKE_ROLE")
 localhost = load_from_env("LOCALHOST")
 
 
-SQL_QUERY_TEMPLATE = """
+CANDIDATE_LIST_SQL_QUERY_TEMPLATE = """
+SELECT DISTINCT att.product
+FROM cybersyn.bureau_of_labor_statistics_price_timeseries AS ts
+JOIN cybersyn.bureau_of_labor_statistics_price_attributes AS att
+    ON (ts.variable = att.variable)
+WHERE ts.date >= '2021-01-01'
+  AND att.report = 'Average Price'
+  AND att.product ILIKE :good_pattern;
+"""
+
+
+TIMESERIES_SQL_QUERY_TEMPLATE = """
 SELECT ts.date,
        att.variable_name,
        ts.value
@@ -26,9 +37,42 @@
 """
 
 
+def get_list_of_candidate_goods(good: str) -> List[str]:
+    """Return the numbered products whose name starts with ``good``.
+
+    Each entry is formatted as ``"<n>. <product>"`` with ``n`` starting at 1.
+    ``good`` is passed as a bound parameter instead of being formatted into
+    the SQL text, so quotes in user input cannot change the statement.
+    """
+    url = URL(
+        account=snowflake_account,
+        user=snowflake_user,
+        password=snowflake_password,
+        database="FINANCIAL__ECONOMIC_ESSENTIALS",
+        schema="CYBERSYN",
+        warehouse="COMPUTE_WH",
+        role=snowflake_role,
+    )
+
+    engine = create_engine(url)
+    try:
+        # The context manager closes the connection even if connect() or
+        # execute() raises; rows are fetched while it is still open.
+        with engine.connect() as connection:
+            rows = connection.execute(
+                text(CANDIDATE_LIST_SQL_QUERY_TEMPLATE),
+                {"good_pattern": f"{good}%"},
+            ).fetchall()
+    finally:
+        # This engine is created per call, so release its pool here.
+        engine.dispose()
+
+    return [f"{ix}. {row[0]}" for ix, row in enumerate(rows, start=1)]
+
+
 def get_time_series_of_good(good: str) -> str:
     """Create a time series of the average price paid for a good nationwide starting in 2021."""
-    query = SQL_QUERY_TEMPLATE.format(good=good)
+    query = TIMESERIES_SQL_QUERY_TEMPLATE.format(good=good)
     url = URL(
         account=snowflake_account,
         user=snowflake_user,
diff --git a/snowflake_cybersyn_demo/workflows/financial_and_economic_essentials.py b/snowflake_cybersyn_demo/workflows/financial_and_economic_essentials.py
index a9ecc47..9879d53 100644
--- a/snowflake_cybersyn_demo/workflows/financial_and_economic_essentials.py
+++ b/snowflake_cybersyn_demo/workflows/financial_and_economic_essentials.py
@@ -8,54 +8,8 @@
     step,
 )
 from llama_index.llms.openai import OpenAI
-from snowflake.sqlalchemy import URL
-from sqlalchemy import create_engine, text
 
-from snowflake_cybersyn_demo.utils import load_from_env
-from snowflake_cybersyn_demo.workflows.sql_queries import (
-    get_time_series_of_good,
-    perform_price_aggregation,
-)
-
-snowflake_user = load_from_env("SNOWFLAKE_USERNAME")
-snowflake_password = load_from_env("SNOWFLAKE_PASSWORD")
-snowflake_account = load_from_env("SNOWFLAKE_ACCOUNT")
-snowflake_role = load_from_env("SNOWFLAKE_ROLE")
-
-SQL_QUERY_TEMPLATE = """
-SELECT DISTINCT att.product,
-FROM cybersyn.bureau_of_labor_statistics_price_timeseries AS ts
-JOIN cybersyn.bureau_of_labor_statistics_price_attributes AS att
-    ON (ts.variable = att.variable)
-WHERE ts.date >= '2021-01-01'
-  AND att.report = 'Average Price'
-  AND att.product ILIKE '{good}%';
-"""
-
-
-def get_list_of_candidate_goods(good: str) -> List[str]:
-    """Returns a list of goods that exist in the database.
-
-    The list of goods is represented as a string separated by '\n'."""
-    query = SQL_QUERY_TEMPLATE.format(good=good)
-    url = URL(
-        account=snowflake_account,
-        user=snowflake_user,
-        password=snowflake_password,
-        database="FINANCIAL__ECONOMIC_ESSENTIALS",
-        schema="CYBERSYN",
-        warehouse="COMPUTE_WH",
-        role=snowflake_role,
-    )
-
-    engine = create_engine(url)
-    try:
-        connection = engine.connect()
-        results = connection.execute(text(query))
-    finally:
-        connection.close()
-
-    return [f"{ix+1}. {str(el[0])}" for ix, el in enumerate(results)]
+import snowflake_cybersyn_demo.workflows._db as db
 
 
 class CandidateLookupEvent(Event):
@@ -74,7 +28,7 @@ async def retrieve_candidates_from_db(
     ) -> CandidateLookupEvent:
         # Your workflow logic here
         good = str(ev.get("good", ""))
-        candidates = get_list_of_candidate_goods(good=good)
+        candidates = db.get_list_of_candidate_goods(good=good)
         return CandidateLookupEvent(candidates=candidates)
 
     @step
@@ -108,9 +62,9 @@ async def human_input(self, ev: CandidateLookupEvent) -> HumanInputEvent:
 
     @step
     async def get_time_series_data(self, ev: HumanInputEvent) -> StopEvent:
-        timeseries_data_str = get_time_series_of_good(good=ev.selected_good)
+        timeseries_data_str = db.get_time_series_of_good(good=ev.selected_good)
         # aggregation
-        aggregated_timeseries_data = perform_price_aggregation(
+        aggregated_timeseries_data = db.perform_price_aggregation(
             timeseries_data_str
         )
         return StopEvent(result=aggregated_timeseries_data)