Skip to content

Commit

Permalink
Add Cloudflare KV caching
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Oct 18, 2024
1 parent 3c6472c commit fde0064
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 49 deletions.
88 changes: 49 additions & 39 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage

from template import TEMPLATE
from tools import retriever_tool

from tools import search, sql_executor_tool
from PIL import Image
from io import BytesIO

@dataclass
class MessagesState:
Expand All @@ -32,39 +33,43 @@ class ModelConfig:
base_url: Optional[str] = None


def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
model_configurations = {
"gpt-4o": ModelConfig(
model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
),
"Gemini Flash 1.5 8B": ModelConfig(
model_name="google/gemini-flash-1.5-8b",
api_key=st.secrets["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
),
"claude3-haiku": ModelConfig(
model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
),
"llama-3.2-3b": ModelConfig(
model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
api_key=os.getenv("FIREWORKS_API_KEY"),
base_url="https://api.fireworks.ai/inference/v1",
),
"llama-3.1-405b": ModelConfig(
model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
api_key=os.getenv("FIREWORKS_API_KEY"),
base_url="https://api.fireworks.ai/inference/v1",
),
}
model_configurations = {
"gpt-4o": ModelConfig(
model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
),
"Gemini Flash 1.5 8B": ModelConfig(
model_name="google/gemini-flash-1.5-8b",
api_key=st.secrets["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
),
"claude3-haiku": ModelConfig(
model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
),
"llama-3.2-3b": ModelConfig(
model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
api_key=os.getenv("FIREWORKS_API_KEY"),
base_url="https://api.fireworks.ai/inference/v1",
),
"llama-3.1-405b": ModelConfig(
model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
api_key=os.getenv("FIREWORKS_API_KEY"),
base_url="https://api.fireworks.ai/inference/v1",
),
}
sys_msg = SystemMessage(
content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
- Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
- Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
- Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code.
"""
)
tools = [retriever_tool, search, sql_executor_tool]

def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
config = model_configurations.get(model_name)
if not config:
raise ValueError(f"Unsupported model name: {model_name}")

sys_msg = SystemMessage(
content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code.
"""
)

llm = (
ChatOpenAI(
Expand All @@ -73,6 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
callbacks=[callback_handler],
streaming=True,
base_url=config.base_url,
temperature=0.01,
)
if config.model_name != "claude-3-haiku-20240307"
else ChatAnthropic(
Expand All @@ -83,21 +89,25 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
)
)

tools = [retriever_tool]

llm_with_tools = llm.bind_tools(tools)

def reasoner(state: MessagesState):
def llm_agent(state: MessagesState):
return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}

# Build the graph
builder = StateGraph(MessagesState)
builder.add_node("reasoner", reasoner)
builder.add_node("llm_agent", llm_agent)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "reasoner")
builder.add_conditional_edges("reasoner", tools_condition)
builder.add_edge("tools", "reasoner")
builder.add_edge(START, "llm_agent")
builder.add_conditional_edges("llm_agent", tools_condition)
builder.add_edge("tools", "llm_agent")

react_graph = builder.compile(checkpointer=memory)

# png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
# with open("graph.png", "wb") as f:
# f.write(png_data)

# image = Image.open(BytesIO(png_data))
# st.image(image, caption="React Graph")

return react_graph
Binary file added graph.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def execute_sql(query, conn, retries=2):
messages = [HumanMessage(content=user_input_content)]

state = MessagesState(messages=messages)
result = react_graph.invoke(state, config=config)
result = react_graph.invoke(state, config=config, debug=True)

if result["messages"]:
assistant_message = callback_handler.final_message
Expand Down
19 changes: 15 additions & 4 deletions tools.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import streamlit as st
from langchain.prompts.prompt import PromptTemplate
from supabase.client import Client, create_client
from langchain.tools.retriever import create_retriever_tool
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool
from langchain_community.tools import DuckDuckGoSearchRun
from utils.snow_connect import SnowflakeConnection

supabase_url = st.secrets["SUPABASE_URL"]
supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
supabase: Client = create_client(supabase_url, supabase_key)


embeddings = OpenAIEmbeddings(
openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
)
Expand All @@ -20,9 +20,20 @@
query_name="v_match_documents",
)


# Retriever over the Supabase vector store, exposed to the agent as the
# "Database_Schema" tool for looking up table/column details.
retriever_tool = create_retriever_tool(
    vectorstore.as_retriever(),
    name="Database_Schema",
    description="Search for database schema details",
)

# DuckDuckGo web search tool for Snowflake-SQL-related internet lookups.
search = DuckDuckGoSearchRun()

def sql_executor_tool(query: str, use_cache: bool = True) -> list:
    """
    Execute a Snowflake SQL query and return the result rows.

    Parameters
    ----------
    query : str
        The SQL statement to run against Snowflake.
    use_cache : bool, optional
        When True (default), SnowflakeConnection consults the Cloudflare KV
        cache before hitting Snowflake and writes fresh results back to it.

    Returns
    -------
    list
        The query results as a list of row dicts (one dict per row), as
        produced by SnowflakeConnection.execute_query.  (The previous
        ``-> str`` annotation was incorrect.)
    """
    # A new connection wrapper per call; the underlying Snowpark session is
    # created lazily inside execute_query via get_session().
    conn = SnowflakeConnection()
    return conn.execute_query(query, use_cache)


if __name__ == "__main__":
    print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS"))
60 changes: 55 additions & 5 deletions utils/snow_connect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import requests
import streamlit as st
from snowflake.snowpark.session import Session


class SnowflakeConnection:
"""
This class is used to establish a connection to Snowflake.
This class is used to establish a connection to Snowflake and execute queries with optional caching.
Attributes
----------
Expand All @@ -19,16 +20,24 @@ class SnowflakeConnection:
-------
get_session()
Establishes and returns the Snowflake connection session.
execute_query(query: str, use_cache: bool = True)
Executes a Snowflake SQL query with optional caching.
"""

    def __init__(self):
        # Snowflake credentials read from Streamlit secrets.
        self.connection_parameters = self._get_connection_parameters_from_env()
        # Snowpark session is created lazily by get_session().
        self.session = None
        # Cloudflare KV settings used for query-result caching.
        self.cloudflare_account_id = st.secrets["CLOUDFLARE_ACCOUNT_ID"]
        self.cloudflare_namespace_id = st.secrets["CLOUDFLARE_NAMESPACE_ID"]
        self.cloudflare_api_token = st.secrets["CLOUDFLARE_API_TOKEN"]
        # Shared headers for every Cloudflare KV REST call.
        self.headers = {
            "Authorization": f"Bearer {self.cloudflare_api_token}",
            "Content-Type": "application/json"
        }

@staticmethod
def _get_connection_parameters_from_env() -> Dict[str, Any]:
connection_parameters = {
return {
"account": st.secrets["ACCOUNT"],
"user": st.secrets["USER_NAME"],
"password": st.secrets["PASSWORD"],
Expand All @@ -37,7 +46,6 @@ def _get_connection_parameters_from_env() -> Dict[str, Any]:
"schema": st.secrets["SCHEMA"],
"role": st.secrets["ROLE"],
}
return connection_parameters

def get_session(self):
"""
Expand All @@ -49,3 +57,45 @@ def get_session(self):
self.session = Session.builder.configs(self.connection_parameters).create()
self.session.sql_simplifier_enabled = True
return self.session

def _construct_kv_url(self, key: str) -> str:
return f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/storage/kv/namespaces/{self.cloudflare_namespace_id}/values/{key}"

def get_from_cache(self, key: str) -> str:
url = self._construct_kv_url(key)
try:
response = requests.get(url, headers=self.headers)
response.raise_for_status()
print("\n\n\nCache hit\n\n\n")
return response.text
except requests.exceptions.RequestException as e:
print(f"Cache miss or error: {e}")
return None

def set_to_cache(self, key: str, value: str) -> None:
url = self._construct_kv_url(key)
serialized_value = json.dumps(value)
try:
response = requests.put(url, headers=self.headers, data=serialized_value)
response.raise_for_status()
print("Cache set successfully")
except requests.exceptions.RequestException as e:
print(f"Failed to set cache: {e}")

def execute_query(self, query: str, use_cache: bool = True) -> str:
"""
Execute a Snowflake SQL query with optional caching.
"""
if use_cache:
cached_response = self.get_from_cache(query)
if cached_response:
return json.loads(cached_response)

session = self.get_session()
result = session.sql(query).collect()
result_list = [row.as_dict() for row in result]

if use_cache:
self.set_to_cache(query, result_list)

return result_list

0 comments on commit fde0064

Please sign in to comment.