-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add financial_and_economic_essentials agent
- Loading branch information
Showing
7 changed files
with
829 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ docker-compose.local.yml | |
pyproject.local.toml | ||
__pycache__ | ||
data | ||
notebooks |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
78 changes: 78 additions & 0 deletions
78
...lake_cybersyn_demo/agent_services/financial_and_economic_essentials/goods_getter_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine, text

from snowflake_cybersyn_demo.utils import load_from_env

# Deployment settings, all supplied through environment variables.
message_queue_host = load_from_env("RABBITMQ_HOST")
message_queue_port = load_from_env("RABBITMQ_NODE_PORT")
message_queue_username = load_from_env("RABBITMQ_DEFAULT_USER")
message_queue_password = load_from_env("RABBITMQ_DEFAULT_PASS")
control_plane_host = load_from_env("CONTROL_PLANE_HOST")
control_plane_port = load_from_env("CONTROL_PLANE_PORT")
# NOTE(review): the FUNNY_AGENT_* names look copy-pasted from another agent
# service; confirm these are the intended host/port variables for this agent.
funny_agent_host = load_from_env("FUNNY_AGENT_HOST")
funny_agent_port = load_from_env("FUNNY_AGENT_PORT")
localhost = load_from_env("LOCALHOST")


# Message queue used by the llama-agents control plane to reach this service.
message_queue = RabbitMQMessageQueue(
    url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
)

# Candidate-goods lookup; `{good}` is substituted via str.format by the tool
# below. Fix: the original read `SELECT DISTINCT att.product,` — the trailing
# comma before FROM is a SQL syntax error.
SQL_QUERY_TEMPLATE = """
SELECT DISTINCT att.product
FROM cybersyn.bureau_of_labor_statistics_price_timeseries AS ts
JOIN cybersyn.bureau_of_labor_statistics_price_attributes AS att
    ON (ts.variable = att.variable)
WHERE ts.date >= '2021-01-01'
    AND att.report = 'Average Price'
    AND att.product ILIKE '{good}%';
"""

# System prompt for the candidate-goods agent defined at the bottom of this
# module.
AGENT_SYSTEM_PROMPT = """
For a given query about a good in the database, your job is to first find
if the good exists in the database. Return the list of goods in the database
that potentially match the object of the users query.
"""
|
||
|
||
def get_list_of_candidate_goods(good: str) -> str:
    r"""Return products in the database whose name matches ``good``.

    Runs the DISTINCT-product query against the Snowflake Cybersyn dataset
    and returns the candidates as one string, one product per line
    (separated by '\n').

    Args:
        good: Product-name prefix to match (case-insensitive, via ILIKE).

    Returns:
        Newline-separated product names; empty string when nothing matches.
    """
    # NOTE(security): `good` is interpolated into the SQL text via
    # str.format — SQL-injection risk. Switch the module template to a bound
    # parameter (e.g. `ILIKE :pattern`) when the template can be changed.
    query = SQL_QUERY_TEMPLATE.format(good=good)
    # NOTE(security): credentials are hard-coded in source; move them to
    # environment variables via load_from_env like the other settings in
    # this module, and rotate the exposed password.
    url = URL(
        account="AZXOMEC-NZB11223",
        user="NERDAILLAMAINDEX",
        password="b307gJ5YzR8k",
        database="FINANCIAL__ECONOMIC_ESSENTIALS",
        schema="CYBERSYN",
        warehouse="COMPUTE_WH",
        role="ACCOUNTADMIN",
    )

    engine = create_engine(url)
    try:
        # Context manager guarantees the connection closes even on error.
        # The original bound `connection` inside `try` and closed it in
        # `finally`, which raises NameError when engine.connect() itself
        # fails; it also iterated the result after closing the connection.
        with engine.connect() as connection:
            # Materialize all rows while the connection is still open.
            rows = connection.execute(text(query)).fetchall()
    finally:
        engine.dispose()  # release the pooled connections of this throwaway engine

    return "\n".join(str(row[0]) for row in rows)
|
||
|
||
# Expose the candidate-goods lookup as a callable tool and build the
# OpenAI-backed agent that serves it.
goods_getter_tool = FunctionTool.from_defaults(fn=get_list_of_candidate_goods)
_llm = OpenAI(model="gpt-4o-mini")
_tools = [goods_getter_tool]
agent = OpenAIAgent.from_tools(
    _tools,
    system_prompt=AGENT_SYSTEM_PROMPT,
    llm=_llm,
    verbose=True,
)
105 changes: 105 additions & 0 deletions
105
...ybersyn_demo/agent_services/financial_and_economic_essentials/time_series_getter_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import json

from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine, text

from snowflake_cybersyn_demo.utils import load_from_env

# Deployment settings, all supplied through environment variables.
message_queue_host = load_from_env("RABBITMQ_HOST")
message_queue_port = load_from_env("RABBITMQ_NODE_PORT")
message_queue_username = load_from_env("RABBITMQ_DEFAULT_USER")
message_queue_password = load_from_env("RABBITMQ_DEFAULT_PASS")
control_plane_host = load_from_env("CONTROL_PLANE_HOST")
control_plane_port = load_from_env("CONTROL_PLANE_PORT")
# NOTE(review): the FUNNY_AGENT_* names look copy-pasted from another agent
# service; confirm these are the intended host/port variables for this agent.
funny_agent_host = load_from_env("FUNNY_AGENT_HOST")
funny_agent_port = load_from_env("FUNNY_AGENT_PORT")
localhost = load_from_env("LOCALHOST")


# Message queue used by the llama-agents control plane to reach this service.
message_queue = RabbitMQMessageQueue(
    url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
)

# System prompt for the time-series agent defined at the bottom of this
# module. Fix: corrected the typos "folowing" -> "following" and
# "ouput" -> "output" in the original prompt text.
# NOTE(review): the illustrated shape wraps a JSON array in braces, which is
# not itself valid JSON — confirm whether the braces are intentional.
AGENT_SYSTEM_PROMPT = """
Query the database to return timeseries data of user-specified good.
Use the tool to return the time series data as a JSON with the following format:
{{
    [
        {{
            "good": ...,
            "date": ...,
            "price": ...
        }},
        {{
            "good": ...,
            "date": ...,
            "price": ...
        }},
        ...
    ]
}}
Don't return the output as markdown code. Don't modify the tool output. Return
strictly the tool output.
"""

# Average-price time series for a good; `{good}` is substituted via
# str.format by the tool below.
SQL_QUERY_TEMPLATE = """
SELECT ts.date,
       att.variable_name,
       ts.value
FROM cybersyn.bureau_of_labor_statistics_price_timeseries AS ts
JOIN cybersyn.bureau_of_labor_statistics_price_attributes AS att
    ON (ts.variable = att.variable)
WHERE ts.date >= '2021-01-01'
    AND att.report = 'Average Price'
    AND att.product ILIKE '{good}%'
ORDER BY date;
"""
|
||
|
||
def get_time_series_of_good(good: str) -> str:
    """Create a time series of the average price paid for a good nationwide starting in 2021.

    Args:
        good: Product-name prefix to match (case-insensitive, via ILIKE).

    Returns:
        A JSON string: a list of ``{"good", "date", "price"}`` records
        ordered by date. All values are serialized as strings.
    """
    # NOTE(security): `good` is interpolated into the SQL text via
    # str.format — SQL-injection risk. Switch the module template to a bound
    # parameter (e.g. `ILIKE :pattern`) when the template can be changed.
    query = SQL_QUERY_TEMPLATE.format(good=good)
    # NOTE(security): credentials are hard-coded in source; move them to
    # environment variables via load_from_env like the other settings in
    # this module, and rotate the exposed password.
    url = URL(
        account="AZXOMEC-NZB11223",
        user="NERDAILLAMAINDEX",
        password="b307gJ5YzR8k",
        database="FINANCIAL__ECONOMIC_ESSENTIALS",
        schema="CYBERSYN",
        warehouse="COMPUTE_WH",
        role="ACCOUNTADMIN",
    )

    engine = create_engine(url)
    try:
        # Context manager guarantees the connection closes even on error.
        # The original bound `connection` inside `try` and closed it in
        # `finally`, which raises NameError when engine.connect() itself
        # fails; it also iterated the result after closing the connection.
        with engine.connect() as connection:
            # Materialize all rows while the connection is still open.
            rows = connection.execute(text(query)).fetchall()
    finally:
        engine.dispose()  # release the pooled connections of this throwaway engine

    # Row layout follows the SELECT list: (date, variable_name, value).
    records = [
        {"good": str(row[1]), "date": str(row[0]), "price": str(row[2])}
        for row in rows
    ]
    return json.dumps(records, indent=4)
|
||
|
||
# Wrap the time-series query as a direct-return tool (its JSON goes straight
# back to the caller) and build the OpenAI-backed agent that serves it.
goods_getter_tool = FunctionTool.from_defaults(
    fn=get_time_series_of_good,
    return_direct=True,
)
_llm = OpenAI(model="gpt-3.5-turbo")
_tools = [goods_getter_tool]
agent = OpenAIAgent.from_tools(
    _tools,
    system_prompt=AGENT_SYSTEM_PROMPT,
    llm=_llm,
    verbose=True,
)
77 changes: 77 additions & 0 deletions
77
...rsyn_demo/agent_services/financial_and_economic_essentials/time_series_processer_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import json
from typing import Dict, List

from llama_agents.message_queues.rabbitmq import RabbitMQMessageQueue
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI

from snowflake_cybersyn_demo.utils import load_from_env

# Deployment settings, all supplied through environment variables.
message_queue_host = load_from_env("RABBITMQ_HOST")
message_queue_port = load_from_env("RABBITMQ_NODE_PORT")
message_queue_username = load_from_env("RABBITMQ_DEFAULT_USER")
message_queue_password = load_from_env("RABBITMQ_DEFAULT_PASS")
control_plane_host = load_from_env("CONTROL_PLANE_HOST")
control_plane_port = load_from_env("CONTROL_PLANE_PORT")
# NOTE(review): the FUNNY_AGENT_* names look copy-pasted from another agent
# service; confirm these are the intended host/port variables for this agent.
funny_agent_host = load_from_env("FUNNY_AGENT_HOST")
funny_agent_port = load_from_env("FUNNY_AGENT_PORT")
localhost = load_from_env("LOCALHOST")


# Message queue used by the llama-agents control plane to reach this service.
message_queue = RabbitMQMessageQueue(
    url=f"amqp://{message_queue_username}:{message_queue_password}@{message_queue_host}:{message_queue_port}/"
)

# System prompt for the price-aggregation agent defined at the bottom of this
# module. Fix: corrected the typo "folowing" -> "following" in the original
# prompt text.
# NOTE(review): the illustrated shape wraps a JSON array in braces, which is
# not itself valid JSON — confirm whether the braces are intentional.
AGENT_SYSTEM_PROMPT = """
Perform price aggregation on the time series data to ensure that each date only
has one associated price.
Return the time series data as a JSON with the following format:
{{
    [
        {{
            "good": ...,
            "date": ...,
            "price": ...
        }}
    ]
}}
Don't return the output as markdown code.
"""
|
||
|
||
def perform_price_aggregation(json_str: str) -> str:
    """Average duplicate-date prices so each date has exactly one price.

    Args:
        json_str: JSON array of ``{"good", "date", "price"}`` records.
            Prices may be numbers or numeric strings — the upstream
            time-series agent serializes them as strings.

    Returns:
        JSON array (indent=4) with one record per date; ``price`` is the
        arithmetic mean of that date's prices as a float. Returns ``"[]"``
        for an empty input array.
    """
    timeseries_data = json.loads(json_str)
    if not timeseries_data:
        # Guard: the original indexed [0] and raised IndexError on empty input.
        return json.dumps([], indent=4)

    # All records in one payload describe the same good.
    good = timeseries_data[0]["good"]

    prices_by_date: Dict[str, List[float]] = {}
    for record in timeseries_data:
        # Coerce to float: upstream serializes prices as strings, and the
        # original's sum() over strings raised TypeError.
        prices_by_date.setdefault(record["date"], []).append(
            float(record["price"])
        )

    # dicts preserve insertion order, so output dates keep input order.
    reduced_time_series_data = [
        {"good": good, "date": date, "price": sum(prices) / len(prices)}
        for date, prices in prices_by_date.items()
    ]
    return json.dumps(reduced_time_series_data, indent=4)
|
||
|
||
# Wrap the aggregation step as a direct-return tool (its JSON goes straight
# back to the caller) and build the OpenAI-backed agent that serves it.
price_aggregation_tool = FunctionTool.from_defaults(
    fn=perform_price_aggregation,
    return_direct=True,
)
_llm = OpenAI(model="gpt-3.5-turbo")
_tools = [price_aggregation_tool]
agent = OpenAIAgent.from_tools(
    _tools,
    system_prompt=AGENT_SYSTEM_PROMPT,
    llm=_llm,
    verbose=True,
)