From fde0064fd6c76506569ca09d21343e51563add01 Mon Sep 17 00:00:00 2001
From: kaarthik108
Date: Fri, 18 Oct 2024 19:10:31 +1300
Subject: [PATCH] Add Cloudflare KV caching

---
 agent.py              | 88 +++++++++++++++++++++++-------------------
 graph.png             | Bin 0 -> 8039 bytes
 main.py               |  2 +-
 tools.py              | 19 +++++++--
 utils/snow_connect.py | 60 +++++++++++++++++++++++++---
 5 files changed, 120 insertions(+), 49 deletions(-)
 create mode 100644 graph.png

diff --git a/agent.py b/agent.py
index d3324fd..fa0f2a6 100644
--- a/agent.py
+++ b/agent.py
@@ -13,9 +13,10 @@ from langgraph.graph.message import add_messages
 from langchain_core.messages import BaseMessage
 
-from template import TEMPLATE
 from tools import retriever_tool
-
+from tools import search, sql_executor_tool
+from PIL import Image
+from io import BytesIO
 
 
 @dataclass
 class MessagesState:
@@ -32,39 +33,43 @@ class ModelConfig:
     base_url: Optional[str] = None
 
 
-def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
-    model_configurations = {
-        "gpt-4o": ModelConfig(
-            model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
-        ),
-        "Gemini Flash 1.5 8B": ModelConfig(
-            model_name="google/gemini-flash-1.5-8b",
-            api_key=st.secrets["OPENROUTER_API_KEY"],
-            base_url="https://openrouter.ai/api/v1",
-        ),
-        "claude3-haiku": ModelConfig(
-            model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
-        ),
-        "llama-3.2-3b": ModelConfig(
-            model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
-            api_key=os.getenv("FIREWORKS_API_KEY"),
-            base_url="https://api.fireworks.ai/inference/v1",
-        ),
-        "llama-3.1-405b": ModelConfig(
-            model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
-            api_key=os.getenv("FIREWORKS_API_KEY"),
-            base_url="https://api.fireworks.ai/inference/v1",
-        ),
-    }
+model_configurations = {
+    "gpt-4o": ModelConfig(
+        model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
+    ),
+    "Gemini Flash 1.5 8B": ModelConfig(
+        model_name="google/gemini-flash-1.5-8b",
+        api_key=st.secrets["OPENROUTER_API_KEY"],
+        base_url="https://openrouter.ai/api/v1",
+    ),
+    "claude3-haiku": ModelConfig(
+        model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
+    ),
+    "llama-3.2-3b": ModelConfig(
+        model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
+        api_key=os.getenv("FIREWORKS_API_KEY"),
+        base_url="https://api.fireworks.ai/inference/v1",
+    ),
+    "llama-3.1-405b": ModelConfig(
+        model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
+        api_key=os.getenv("FIREWORKS_API_KEY"),
+        base_url="https://api.fireworks.ai/inference/v1",
+    ),
+}
+sys_msg = SystemMessage(
+    content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
+    - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
+    - Internet_Search: This tool allows you to search the internet for Snowflake SQL related information when needed to generate the SQL code.
+    - Snowflake_SQL_Executor: This tool allows you to execute Snowflake SQL queries when needed.
+    """
+)
+tools = [retriever_tool, search, sql_executor_tool]
+
+def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
     config = model_configurations.get(model_name)
     if not config:
         raise ValueError(f"Unsupported model name: {model_name}")
 
-    sys_msg = SystemMessage(
-        content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
-        Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code.
-        """
-    )
 
     llm = (
         ChatOpenAI(
@@ -73,6 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
             callbacks=[callback_handler],
             streaming=True,
             base_url=config.base_url,
+            temperature=0.01,
         )
         if config.model_name != "claude-3-haiku-20240307"
         else ChatAnthropic(
@@ -83,21 +89,25 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
         )
     )
 
-    tools = [retriever_tool]
-
     llm_with_tools = llm.bind_tools(tools)
 
-    def reasoner(state: MessagesState):
+    def llm_agent(state: MessagesState):
        return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
 
-    # Build the graph
     builder = StateGraph(MessagesState)
-    builder.add_node("reasoner", reasoner)
+    builder.add_node("llm_agent", llm_agent)
     builder.add_node("tools", ToolNode(tools))
-    builder.add_edge(START, "reasoner")
-    builder.add_conditional_edges("reasoner", tools_condition)
-    builder.add_edge("tools", "reasoner")
+    builder.add_edge(START, "llm_agent")
+    builder.add_conditional_edges("llm_agent", tools_condition)
+    builder.add_edge("tools", "llm_agent")
 
     react_graph = builder.compile(checkpointer=memory)
 
+    # png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
+    # with open("graph.png", "wb") as f:
+    #     f.write(png_data)
+
+    # image = Image.open(BytesIO(png_data))
+    # st.image(image, caption="React Graph")
+
     return react_graph
diff --git a/graph.png b/graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..dcb88cd09d3a1a3d86f8a5b8262ebe6adc87849f
GIT binary patch
literal 8039
[base85-encoded data for the 8039-byte PNG omitted]
z)fr9{0d3g!i%!v_eCw{4cg84@cf9Rj25$T#dxWmt?`K2}r^w5KaN?jOiDB!blAExp z(Q>YFvmeqs_DpSkqCoiBh3n584`@E`(QagaoMW-E$kQY!!IOR=>+)N#6mL4PtH58h z^$*)cW_M@iwh7yb^)n5RpE1aWrs+eZ@WFku&`%k-QydWjuoP1$N6h5I-P}OM<~48a z@TR9fr|dOUD^Gfc5(l;_vO`eF((9u)h;4$0t}Y$D=x;vPcE4b1f#YMc>B!mGtV7SA z!8FErIo4%se(0HoD)c95_+KR7f&zZE?Cz&L6@i`F-E&!0Pjzjtao|{&md!7_Nfj%V z)mL>%=5fAlsuD&x(j-h;{FJETW?2>zjlOoeD0tDAsL(4;(+Me%`|^Cv>pmsj;?>YW zqnI0!1B<8lI;F$&i@f_PwLi2|I>pSK=)T^4bvzx-s-A%QKgA3^$?BaOEAjT0zsPPHRDd5c+R`;y5tE+>0)(pA%o&8t|Gu4e;`96Sb z>^D`ePVbHXfKL;Y)zpS z%(ry48eN`zA&hq}RieRr1~&bPd@uMql?T6T zkPN2pWRy^u8K3jB2I|JYVUX0JV0CqG*ChO~cc z_pu~&v0?iD>ZAL>VNLxZo$cS;ln&^0Rimogz%QfE56_aye!#kiyw-aMJk?mqPGYMZ z=kAk%l%N=ym<7yE;=E-`c-DMF{Un5#3pUsgQv?zz=Tjm9y()a!^t2%IO=>%@!bpAO zr*+Bn{KAu2^T`i+yxD`jwAoEM+8LB|WcVZry%4{rBeX&Z()LP@5(?#CmDzvn*-xvr zL1~{4Yf*Vw^jT}wUfr|omF7D1HVTye4I$6$G4yCym-cRr_7QJyy24ia#I`T>Z%6}G z_`^Z9YhQmFq2xnV=jYydr1=|aeWofkb?TP5Bh}+kK$d-1Ti=rDRnya=iDQcb4e`l| g(*ss}G0l)rEv-R=$=q)?l$0-jCvv}$+~w$h0rjWn;{X5v literal 0 HcmV?d00001 diff --git a/main.py b/main.py index b75b3f5..ddbe5d1 100644 --- a/main.py +++ b/main.py @@ -204,7 +204,7 @@ def execute_sql(query, conn, retries=2): messages = [HumanMessage(content=user_input_content)] state = MessagesState(messages=messages) - result = react_graph.invoke(state, config=config) + result = react_graph.invoke(state, config=config, debug=True) if result["messages"]: assistant_message = callback_handler.final_message diff --git a/tools.py b/tools.py index 5b5a450..2be57ba 100644 --- a/tools.py +++ b/tools.py @@ -1,15 +1,15 @@ import streamlit as st -from langchain.prompts.prompt import PromptTemplate from supabase.client import Client, create_client -from langchain.tools.retriever import create_retriever_tool from langchain_openai import OpenAIEmbeddings from langchain_community.vectorstores import SupabaseVectorStore +from langchain.tools.retriever import create_retriever_tool +from langchain_community.tools import DuckDuckGoSearchRun +from utils.snow_connect import SnowflakeConnection supabase_url = st.secrets["SUPABASE_URL"] supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] supabase: Client = create_client(supabase_url, supabase_key) - embeddings = OpenAIEmbeddings( openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" ) @@ -20,9 +20,20 @@ query_name="v_match_documents", ) - retriever_tool = create_retriever_tool( vectorstore.as_retriever(), name="Database_Schema", description="Search for database schema details", ) + +search = DuckDuckGoSearchRun() + +def sql_executor_tool(query: str, use_cache: bool = True) -> str: + """ + Execute snowflake sql queries with optional caching. + """ + conn = SnowflakeConnection() + return conn.execute_query(query, use_cache) + +if __name__ == "__main__": + print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS")) diff --git a/utils/snow_connect.py b/utils/snow_connect.py index d0b396a..f525adc 100644 --- a/utils/snow_connect.py +++ b/utils/snow_connect.py @@ -1,12 +1,13 @@ from typing import Any, Dict - +import json +import requests import streamlit as st from snowflake.snowpark.session import Session class SnowflakeConnection: """ - This class is used to establish a connection to Snowflake. + This class is used to establish a connection to Snowflake and execute queries with optional caching. Attributes ---------- @@ -19,16 +20,24 @@ class SnowflakeConnection: ------- get_session() Establishes and returns the Snowflake connection session. 
-
+    execute_query(query: str, use_cache: bool = True)
+        Executes a Snowflake SQL query with optional caching.
     """
 
     def __init__(self):
         self.connection_parameters = self._get_connection_parameters_from_env()
         self.session = None
+        self.cloudflare_account_id = st.secrets["CLOUDFLARE_ACCOUNT_ID"]
+        self.cloudflare_namespace_id = st.secrets["CLOUDFLARE_NAMESPACE_ID"]
+        self.cloudflare_api_token = st.secrets["CLOUDFLARE_API_TOKEN"]
+        self.headers = {
+            "Authorization": f"Bearer {self.cloudflare_api_token}",
+            "Content-Type": "application/json"
+        }
 
     @staticmethod
     def _get_connection_parameters_from_env() -> Dict[str, Any]:
-        connection_parameters = {
+        return {
             "account": st.secrets["ACCOUNT"],
             "user": st.secrets["USER_NAME"],
             "password": st.secrets["PASSWORD"],
@@ -37,7 +46,6 @@ def _get_connection_parameters_from_env() -> Dict[str, Any]:
             "schema": st.secrets["SCHEMA"],
             "role": st.secrets["ROLE"],
         }
-        return connection_parameters
 
     def get_session(self):
         """
@@ -49,3 +57,45 @@ def get_session(self):
         self.session = Session.builder.configs(self.connection_parameters).create()
         self.session.sql_simplifier_enabled = True
         return self.session
+
+    def _construct_kv_url(self, key: str) -> str:
+        return f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/storage/kv/namespaces/{self.cloudflare_namespace_id}/values/{key}"
+
+    def get_from_cache(self, key: str) -> str:
+        url = self._construct_kv_url(key)
+        try:
+            response = requests.get(url, headers=self.headers)
+            response.raise_for_status()
+            print("\n\n\nCache hit\n\n\n")
+            return response.text
+        except requests.exceptions.RequestException as e:
+            print(f"Cache miss or error: {e}")
+            return None
+
+    def set_to_cache(self, key: str, value: str) -> None:
+        url = self._construct_kv_url(key)
+        serialized_value = json.dumps(value)
+        try:
+            response = requests.put(url, headers=self.headers, data=serialized_value)
+            response.raise_for_status()
+            print("Cache set successfully")
+        except requests.exceptions.RequestException as e:
+            print(f"Failed to set cache: {e}")
+
+    def execute_query(self, query: str, use_cache: bool = True) -> str:
+        """
+        Execute a Snowflake SQL query with optional caching.
+        """
+        if use_cache:
+            cached_response = self.get_from_cache(query)
+            if cached_response:
+                return json.loads(cached_response)
+
+        session = self.get_session()
+        result = session.sql(query).collect()
+        result_list = [row.as_dict() for row in result]
+
+        if use_cache:
+            self.set_to_cache(query, result_list)
+
+        return result_list
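
Usage sketch (not part of the patch): a minimal illustration of how the new caching path in SnowflakeConnection.execute_query behaves, assuming the Streamlit secrets for Snowflake plus CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_NAMESPACE_ID and CLOUDFLARE_API_TOKEN are configured. The query text here is illustrative only.

    from utils.snow_connect import SnowflakeConnection

    conn = SnowflakeConnection()
    query = "select current_version()"

    first = conn.execute_query(query)    # KV miss: runs on Snowflake, result list written to KV
    second = conn.execute_query(query)   # KV hit: returns json.loads() of the cached value
    third = conn.execute_query(query, use_cache=False)  # always executes, never touches KV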
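
A possible follow-up, sketched here and not implemented in the patch: execute_query uses the raw SQL text as the Cloudflare KV key, so it is interpolated directly into the request URL, and set_to_cache writes entries with no expiry. Hashing the key keeps the URL well formed, and Cloudflare KV's expiration_ttl query parameter bounds staleness. The cache_key helper is hypothetical.

    import hashlib

    def cache_key(query: str) -> str:
        # Normalize whitespace and case so trivially reformatted queries share one entry.
        normalized = " ".join(query.split()).lower()
        return hashlib.sha256(normalized.encode("utf-8")).hexdigest()

    # Inside set_to_cache, an expiry (in seconds) could be added to the PUT:
    #     requests.put(url, headers=self.headers, data=serialized_value,
    #                  params={"expiration_ttl": 3600})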