Skip to content

Commit

Permalink
refactor clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Sep 11, 2024
1 parent 01c3c04 commit 7a7378f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,18 @@
snowflake_role = load_from_env("SNOWFLAKE_ROLE")
localhost = load_from_env("LOCALHOST")

SQL_QUERY_TEMPLATE = """
CANDIDATE_LIST_SQL_QUERY_TEMPLATE = """
SELECT DISTINCT att.product,
FROM cybersyn.bureau_of_labor_statistics_price_timeseries AS ts
JOIN cybersyn.bureau_of_labor_statistics_price_attributes AS att
ON (ts.variable = att.variable)
WHERE ts.date >= '2021-01-01'
AND att.report = 'Average Price'
AND att.product ILIKE '{good}%';
"""


TIMESERIES_SQL_QUERY_TEMPLATE = """
SELECT ts.date,
att.variable_name,
ts.value
Expand All @@ -26,9 +37,34 @@
"""


def get_list_of_candidate_goods(good: str) -> List[str]:
"""Returns a list of goods that exist in the database.
The list of goods is represented as a string separated by '\n'."""
query = CANDIDATE_LIST_SQL_QUERY_TEMPLATE.format(good=good)
url = URL(
account=snowflake_account,
user=snowflake_user,
password=snowflake_password,
database="FINANCIAL__ECONOMIC_ESSENTIALS",
schema="CYBERSYN",
warehouse="COMPUTE_WH",
role=snowflake_role,
)

engine = create_engine(url)
try:
connection = engine.connect()
results = connection.execute(text(query))
finally:
connection.close()

return [f"{ix+1}. {str(el[0])}" for ix, el in enumerate(results)]


def get_time_series_of_good(good: str) -> str:
"""Create a time series of the average price paid for a good nationwide starting in 2021."""
query = SQL_QUERY_TEMPLATE.format(good=good)
query = TIMESERIES_SQL_QUERY_TEMPLATE.format(good=good)
url = URL(
account=snowflake_account,
user=snowflake_user,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,54 +8,8 @@
step,
)
from llama_index.llms.openai import OpenAI
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine, text

from snowflake_cybersyn_demo.utils import load_from_env
from snowflake_cybersyn_demo.workflows.sql_queries import (
get_time_series_of_good,
perform_price_aggregation,
)

snowflake_user = load_from_env("SNOWFLAKE_USERNAME")
snowflake_password = load_from_env("SNOWFLAKE_PASSWORD")
snowflake_account = load_from_env("SNOWFLAKE_ACCOUNT")
snowflake_role = load_from_env("SNOWFLAKE_ROLE")

SQL_QUERY_TEMPLATE = """
SELECT DISTINCT att.product,
FROM cybersyn.bureau_of_labor_statistics_price_timeseries AS ts
JOIN cybersyn.bureau_of_labor_statistics_price_attributes AS att
ON (ts.variable = att.variable)
WHERE ts.date >= '2021-01-01'
AND att.report = 'Average Price'
AND att.product ILIKE '{good}%';
"""


def get_list_of_candidate_goods(good: str) -> List[str]:
"""Returns a list of goods that exist in the database.
The list of goods is represented as a string separated by '\n'."""
query = SQL_QUERY_TEMPLATE.format(good=good)
url = URL(
account=snowflake_account,
user=snowflake_user,
password=snowflake_password,
database="FINANCIAL__ECONOMIC_ESSENTIALS",
schema="CYBERSYN",
warehouse="COMPUTE_WH",
role=snowflake_role,
)

engine = create_engine(url)
try:
connection = engine.connect()
results = connection.execute(text(query))
finally:
connection.close()

return [f"{ix+1}. {str(el[0])}" for ix, el in enumerate(results)]
import snowflake_cybersyn_demo.workflows._db as db


class CandidateLookupEvent(Event):
Expand All @@ -74,7 +28,7 @@ async def retrieve_candidates_from_db(
) -> CandidateLookupEvent:
# Your workflow logic here
good = str(ev.get("good", ""))
candidates = get_list_of_candidate_goods(good=good)
candidates = db.get_list_of_candidate_goods(good=good)
return CandidateLookupEvent(candidates=candidates)

@step
Expand Down Expand Up @@ -108,9 +62,9 @@ async def human_input(self, ev: CandidateLookupEvent) -> HumanInputEvent:

@step
async def get_time_series_data(self, ev: HumanInputEvent) -> StopEvent:
timeseries_data_str = get_time_series_of_good(good=ev.selected_good)
timeseries_data_str = db.get_time_series_of_good(good=ev.selected_good)
# aggregation
aggregated_timeseries_data = perform_price_aggregation(
aggregated_timeseries_data = db.perform_price_aggregation(
timeseries_data_str
)
return StopEvent(result=aggregated_timeseries_data)
Expand Down

0 comments on commit 7a7378f

Please sign in to comment.